| @@ -126,10 +126,10 @@ int64_t DebugInfo::debug_id() { | |||||
| } | } | ||||
| int64_t DebugInfo::unique_id_through_copy() const { | int64_t DebugInfo::unique_id_through_copy() const { | ||||
| TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info(); | |||||
| if (trace_info != nullptr) { | |||||
| if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { | |||||
| return trace_info->debug_info()->unique_id_through_copy(); | |||||
| auto info = trace_info(); | |||||
| if (info != nullptr) { | |||||
| if (info->isa<TraceCopy>() && info->debug_info() != nullptr) { | |||||
| return info->debug_info()->unique_id_through_copy(); | |||||
| } | } | ||||
| } | } | ||||
| return unique_id(); | return unique_id(); | ||||
| @@ -118,7 +118,7 @@ class TraceContext { | |||||
| void set_location(const LocationPtr &loc) { location_ = loc; } | void set_location(const LocationPtr &loc) { location_ = loc; } | ||||
| LocationPtr location() { return location_; } | LocationPtr location() { return location_; } | ||||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | ||||
| TraceInfoPtr trace_info() { return trace_info_; } | |||||
| TraceInfoPtr trace_info() const { return trace_info_; } | |||||
| void set_func_name(const std::string &func_name) { func_name_ = func_name; } | void set_func_name(const std::string &func_name) { func_name_ = func_name; } | ||||
| std::string func_name() { return func_name_; } | std::string func_name() { return func_name_; } | ||||
| }; | }; | ||||
| @@ -139,7 +139,7 @@ class DebugInfo : public Base { | |||||
| std::string get_id() { return std::to_string(debug_id()); } | std::string get_id() { return std::to_string(debug_id()); } | ||||
| void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } | ||||
| TraceInfoPtr trace_info() { return trace_info_; } | |||||
| TraceInfoPtr trace_info() const { return trace_info_; } | |||||
| void set_location(const LocationPtr &loc) { location_ = loc; } | void set_location(const LocationPtr &loc) { location_ = loc; } | ||||
| virtual LocationPtr location() { return location_; } | virtual LocationPtr location() { return location_; } | ||||
| std::string name() { return name_; } | std::string name() { return name_; } | ||||
| @@ -57,9 +57,6 @@ class AbstractFunction; | |||||
| using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>; | using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>; | ||||
| } // namespace abstract | } // namespace abstract | ||||
| class FuncGraphManager; | |||||
| using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; | |||||
| // ANF transform class | // ANF transform class | ||||
| // either a primitive or a func_graph | // either a primitive or a func_graph | ||||
| class FuncGraphTransform { | class FuncGraphTransform { | ||||
| @@ -464,7 +464,7 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t | |||||
| } | } | ||||
| } | } | ||||
| inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| if (input->isa<ValueNode>()) { | if (input->isa<ValueNode>()) { | ||||
| fg->AddValueNode(input); | fg->AddValueNode(input); | ||||
| @@ -485,7 +485,7 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp | |||||
| } | } | ||||
| } | } | ||||
| inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { | |||||
| auto fg = node->func_graph(); | auto fg = node->func_graph(); | ||||
| if (input->isa<ValueNode>()) { | if (input->isa<ValueNode>()) { | ||||
| fg->DropValueNode(input); | fg->DropValueNode(input); | ||||
| @@ -506,7 +506,7 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in | |||||
| } | } | ||||
| } | } | ||||
| inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||||
| void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { | |||||
| target->CopyNodes(source); | target->CopyNodes(source); | ||||
| target->CopyValueNodes(source); | target->CopyValueNodes(source); | ||||
| target->CopyFuncGraphCNodesIndex(source); | target->CopyFuncGraphCNodesIndex(source); | ||||
| @@ -28,9 +28,7 @@ | |||||
| #include "pipeline/static_analysis/abstract_value.h" | #include "pipeline/static_analysis/abstract_value.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace tensor { | namespace tensor { | ||||
| void DataBuf2Contiguous(const py::array &src, py::array *const dest) { | void DataBuf2Contiguous(const py::array &src, py::array *const dest) { | ||||
| if (dest == nullptr) { | if (dest == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; | ||||
| @@ -493,6 +491,5 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { | |||||
| .def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | .def("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") | ||||
| .def("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); | .def("shape", &MetaTensor::shape, "Get the MetaTensor's shape."); | ||||
| })); | })); | ||||
| } // namespace tensor | } // namespace tensor | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,9 +34,7 @@ namespace py = pybind11; | |||||
| using float16 = Eigen::half; | using float16 = Eigen::half; | ||||
| namespace pybind11 { | namespace pybind11 { | ||||
| namespace detail { | namespace detail { | ||||
| // Similar to enums in `pybind11/numpy.h`. Determined by doing: | // Similar to enums in `pybind11/numpy.h`. Determined by doing: | ||||
| // python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' | // python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' | ||||
| constexpr int NPY_FLOAT16 = 23; | constexpr int NPY_FLOAT16 = 23; | ||||
| @@ -85,7 +83,6 @@ template <> | |||||
| struct type_caster<float16> : public npy_scalar_caster<float16> { | struct type_caster<float16> : public npy_scalar_caster<float16> { | ||||
| static constexpr auto name = "float16"; | static constexpr auto name = "float16"; | ||||
| }; | }; | ||||
| } // namespace detail | } // namespace detail | ||||
| } // namespace pybind11 | } // namespace pybind11 | ||||
| @@ -96,7 +93,6 @@ using DeviceAddressPtr = std::shared_ptr<mindspore::device::DeviceAddress>; | |||||
| // mindspore namespace is the top level namespace of Mindsporeession project. | // mindspore namespace is the top level namespace of Mindsporeession project. | ||||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | // Other namespace should be a sub namespace of mindspore namespace in the ME project. | ||||
| namespace mindspore { | namespace mindspore { | ||||
| // brief mindspore::tensor namespace | // brief mindspore::tensor namespace | ||||
| // | // | ||||
| // A sub namespace in ME to support tensor related definition. | // A sub namespace in ME to support tensor related definition. | ||||
| @@ -273,7 +269,6 @@ class Tensor : public MetaTensor { | |||||
| using TensorPtr = std::shared_ptr<Tensor>; | using TensorPtr = std::shared_ptr<Tensor>; | ||||
| using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; | using TensorPtrList = std::vector<std::shared_ptr<Tensor>>; | ||||
| } // namespace tensor | } // namespace tensor | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -39,6 +39,5 @@ class ParamValueMinnie : public ParamValue { | |||||
| }; | }; | ||||
| using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>; | using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_MINNIE_H_ | #endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_MINNIE_H_ | ||||
| @@ -70,7 +70,6 @@ class TensorMinnie : public MetaTensor { | |||||
| }; | }; | ||||
| using TensorMinniePtr = std::shared_ptr<TensorMinnie>; | using TensorMinniePtr = std::shared_ptr<TensorMinnie>; | ||||
| } // namespace tensor | } // namespace tensor | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -39,7 +39,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support composite operators definition | // namespace to support composite operators definition | ||||
| namespace prim { | namespace prim { | ||||
| MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { | ||||
| fn_cache_.clear(); | fn_cache_.clear(); | ||||
| signatures_ = std::vector<Signature>({// def multitype(*args:ref): | signatures_ = std::vector<Signature>({// def multitype(*args:ref): | ||||
| @@ -148,6 +147,5 @@ REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { | |||||
| .def(py::init<std::string &>()) | .def(py::init<std::string &>()) | ||||
| .def("register_fn", &MultitypeFuncGraph::PyRegister); | .def("register_fn", &MultitypeFuncGraph::PyRegister); | ||||
| })); | })); | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,7 +34,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| // namespace to support composite operators definition | // namespace to support composite operators definition | ||||
| namespace prim { | namespace prim { | ||||
| class MultitypeFuncGraph : public MetaFuncGraph { | class MultitypeFuncGraph : public MetaFuncGraph { | ||||
| public: | public: | ||||
| explicit MultitypeFuncGraph(const std::string &name); | explicit MultitypeFuncGraph(const std::string &name); | ||||
| @@ -59,7 +58,6 @@ class MultitypeFuncGraph : public MetaFuncGraph { | |||||
| std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_; | ||||
| }; | }; | ||||
| using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>; | ||||
| } // namespace prim | } // namespace prim | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -88,7 +88,7 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode | |||||
| return result; | return result; | ||||
| } | } | ||||
| inline bool isTraversable(const AnfNodePtr &node) { | |||||
| static bool isTraversable(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ py::function GetComputeFunction(std::string name) { | |||||
| if (!py::hasattr(mod, common::SafeCStr(name))) { | if (!py::hasattr(mod, common::SafeCStr(name))) { | ||||
| PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name)); | PyErr_SetString(PyExc_NotImplementedError, common::SafeCStr(name)); | ||||
| // If raise AttributeError, user can't understand. This case need raise NotImplementedError. | // If raise AttributeError, user can't understand. This case need raise NotImplementedError. | ||||
| throw py::error_already_set(); | |||||
| throw(py::error_already_set()); | |||||
| } | } | ||||
| py::object fn = mod.attr(common::SafeCStr(name)); | py::object fn = mod.attr(common::SafeCStr(name)); | ||||
| return fn; | return fn; | ||||
| @@ -619,7 +619,7 @@ void FinalVM::SyncData(const py::object &arg) { | |||||
| BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | ||||
| MS_LOG(DEBUG) << "input for operation:"; | MS_LOG(DEBUG) << "input for operation:"; | ||||
| std::size_t args_size = args.size(); | std::size_t args_size = args.size(); | ||||
| py::tuple py_args = py::tuple(args_size); | |||||
| auto py_args = py::tuple(args_size); | |||||
| size_t i = 0; | size_t i = 0; | ||||
| for (auto &arg : args) { | for (auto &arg : args) { | ||||
| py_args[i] = BaseRefToPyData(arg); | py_args[i] = BaseRefToPyData(arg); | ||||
| @@ -643,7 +643,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | |||||
| std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | ||||
| if (_hook_grad.find(cell_id) != _hook_grad.end()) { | if (_hook_grad.find(cell_id) != _hook_grad.end()) { | ||||
| std::size_t hook_args_size = 3; | std::size_t hook_args_size = 3; | ||||
| py::tuple hook_args = py::tuple(hook_args_size); | |||||
| auto hook_args = py::tuple(hook_args_size); | |||||
| hook_args[0] = cell_id; | hook_args[0] = cell_id; | ||||
| hook_args[1] = py::make_tuple(_hook_grad[cell_id]); | hook_args[1] = py::make_tuple(_hook_grad[cell_id]); | ||||
| hook_args[2] = py::make_tuple(py_args[2]); | hook_args[2] = py::make_tuple(py_args[2]); | ||||