|
|
@@ -69,32 +69,47 @@ struct GraphInfo { |
|
|
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} |
|
|
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {} |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
struct CellInfo { |
|
|
|
|
|
bool is_grad{false}; // Derivative is calculated |
|
|
|
|
|
bool is_custom_bprop{false}; // Custom bprop |
|
|
|
|
|
FuncGraphPtr fg; // Forward graph |
|
|
|
|
|
std::string cell_id; |
|
|
|
|
|
std::string bprop_cell_id; |
|
|
|
|
|
|
|
|
class CellInfo { |
|
|
|
|
|
public: |
|
|
CellInfo() = default; |
|
|
CellInfo() = default; |
|
|
CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) |
|
|
|
|
|
: is_grad(isgrad), |
|
|
|
|
|
is_custom_bprop(custom_bprop), |
|
|
|
|
|
|
|
|
CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id) |
|
|
|
|
|
: is_custom_bprop(custom_bprop), |
|
|
|
|
|
is_dynamic(has_dynamic), |
|
|
fg(std::move(foward_graph)), |
|
|
fg(std::move(foward_graph)), |
|
|
cell_id(std::move(cellid)), |
|
|
cell_id(std::move(cellid)), |
|
|
bprop_cell_id(std::move(bprop_id)) {} |
|
|
bprop_cell_id(std::move(bprop_id)) {} |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
struct TopCellInfo { |
|
|
|
|
|
ResourcePtr resource; |
|
|
|
|
|
FuncGraphPtr df_builder; |
|
|
|
|
|
FuncGraphPtr bg; // Backward graph |
|
|
|
|
|
|
|
|
bool is_grad{false}; // Derivative is calculated |
|
|
|
|
|
bool is_custom_bprop{false}; // Custom bprop |
|
|
|
|
|
bool is_dynamic{false}; // Set by has_dynamic_cell |
|
|
|
|
|
bool is_real_dynamic{false}; // Set by ops order |
|
|
|
|
|
size_t call_times{0}; |
|
|
|
|
|
FuncGraphPtr fg{nullptr}; // Forward graph |
|
|
std::string cell_id; |
|
|
std::string cell_id; |
|
|
bool is_dynamic_cell{false}; |
|
|
|
|
|
|
|
|
std::string bprop_cell_id; |
|
|
|
|
|
std::vector<std::string> cell_ops_info; // All ops info |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
class TopCellInfo { |
|
|
|
|
|
public: |
|
|
TopCellInfo() = default; |
|
|
TopCellInfo() = default; |
|
|
TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid) |
|
|
|
|
|
: resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {} |
|
|
|
|
|
|
|
|
TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid) |
|
|
|
|
|
: is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {} |
|
|
|
|
|
|
|
|
|
|
|
bool is_topest{false}; |
|
|
|
|
|
bool do_vm_compiled{false}; |
|
|
|
|
|
ResourcePtr resource{nullptr}; |
|
|
|
|
|
FuncGraphPtr df_builder{nullptr}; |
|
|
|
|
|
FuncGraphPtr bg{nullptr}; // Backward graph |
|
|
|
|
|
std::string cell_id; |
|
|
|
|
|
std::string sens_id; |
|
|
|
|
|
std::string weights_id; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
using GraphInfoPtr = std::shared_ptr<GraphInfo>; |
|
|
|
|
|
using CellInfoPtr = std::shared_ptr<CellInfo>; |
|
|
|
|
|
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>; |
|
|
|
|
|
|
|
|
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
public: |
|
|
public: |
|
|
static std::shared_ptr<PynativeExecutor> GetInstance() { |
|
|
static std::shared_ptr<PynativeExecutor> GetInstance() { |
|
|
@@ -119,11 +134,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
void NewGraph(const py::object &cell, const py::args &args); |
|
|
void NewGraph(const py::object &cell, const py::args &args); |
|
|
py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); |
|
|
py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase); |
|
|
py::object CheckGraph(const py::object &cell, const py::args &args); |
|
|
py::object CheckGraph(const py::object &cell, const py::args &args); |
|
|
|
|
|
py::object CheckAlreadyRun(const py::object &cell, const py::args &args); |
|
|
void EndGraph(const py::object &cell, const py::object &out, const py::args &args); |
|
|
void EndGraph(const py::object &cell, const py::object &out, const py::args &args); |
|
|
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); |
|
|
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); |
|
|
|
|
|
|
|
|
// Get info |
|
|
// Get info |
|
|
bool GetIsDynamicCell() const { return dynamic_cell_; } |
|
|
|
|
|
|
|
|
bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); } |
|
|
// Call by python |
|
|
// Call by python |
|
|
void Clear(const std::string &flag = ""); |
|
|
void Clear(const std::string &flag = ""); |
|
|
void Clean(); |
|
|
void Clean(); |
|
|
@@ -149,7 +165,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
template <typename T> |
|
|
template <typename T> |
|
|
void VectorClear(T *vec, const std::string &cell_id) { |
|
|
void VectorClear(T *vec, const std::string &cell_id) { |
|
|
for (auto it = vec->begin(); it != vec->end();) { |
|
|
for (auto it = vec->begin(); it != vec->end();) { |
|
|
if (it->cell_id.find(cell_id) != std::string::npos) { |
|
|
|
|
|
|
|
|
if ((*it)->cell_id.find(cell_id) != std::string::npos) { |
|
|
it = vec->erase(it); |
|
|
it = vec->erase(it); |
|
|
} else { |
|
|
} else { |
|
|
it++; |
|
|
it++; |
|
|
@@ -201,29 +217,39 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); |
|
|
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); |
|
|
void SaveTensorsInValueNode(const ResourcePtr &resource); |
|
|
void SaveTensorsInValueNode(const ResourcePtr &resource); |
|
|
void SaveAllValueNodeTensors(const FuncGraphPtr &graph); |
|
|
void SaveAllValueNodeTensors(const FuncGraphPtr &graph); |
|
|
void CleanPreMemoryInValueNode(const std::string &cell_id); |
|
|
|
|
|
|
|
|
void CleanPreMemoryInValueNode(); |
|
|
|
|
|
|
|
|
// Construct grad graph |
|
|
// Construct grad graph |
|
|
void PushCurrentGraphToStack(); |
|
|
void PushCurrentGraphToStack(); |
|
|
void PopGraphStack(); |
|
|
void PopGraphStack(); |
|
|
|
|
|
void PushCurrentCellOpInfoToStack(); |
|
|
|
|
|
void PopCurrentCellOpInfoFromStack(); |
|
|
FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); |
|
|
FuncGraphPtr GetDfbuilder(const std::string &cell_id = ""); |
|
|
ResourcePtr GetResource(const std::string &cell_id = ""); |
|
|
ResourcePtr GetResource(const std::string &cell_id = ""); |
|
|
void AddNestedGradOrder() { ++grad_order_; } |
|
|
void AddNestedGradOrder() { ++grad_order_; } |
|
|
void SubNestedGradOrder(); |
|
|
void SubNestedGradOrder(); |
|
|
bool IsNotNestedGrad() const; |
|
|
|
|
|
|
|
|
bool IsNestedGrad() const; |
|
|
bool IsTopGraph(const std::string &cell_id); |
|
|
bool IsTopGraph(const std::string &cell_id); |
|
|
|
|
|
bool IsTopestGraph(const std::string &cell_id); |
|
|
bool IsBpropGraph(const std::string &cell_id); |
|
|
bool IsBpropGraph(const std::string &cell_id); |
|
|
|
|
|
bool IsFirstGradStep(const std::string &cell_id); |
|
|
bool grad_running() const { return grad_is_running_; } |
|
|
bool grad_running() const { return grad_is_running_; } |
|
|
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } |
|
|
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; } |
|
|
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } |
|
|
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; } |
|
|
bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; } |
|
|
bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; } |
|
|
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); |
|
|
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false); |
|
|
|
|
|
bool CheckDynamicCell(const std::string &cell_id); |
|
|
|
|
|
bool CheckRealDynamicCell(const std::string &cell_id); |
|
|
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
bool need_cloned = false, bool is_grad = false); |
|
|
bool need_cloned = false, bool is_grad = false); |
|
|
|
|
|
void ClearCnodeRes(const AnfNodePtr &node); |
|
|
|
|
|
void UpdateCellDynamic(const std::string &cell_id); |
|
|
|
|
|
bool CheckCellChanged(const std::string &cell_id); |
|
|
|
|
|
void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled); |
|
|
void ClearResidualRes(const std::string &cell_id); |
|
|
void ClearResidualRes(const std::string &cell_id); |
|
|
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); |
|
|
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph); |
|
|
void NewGraphInner(const py::object &cell, const py::args &args); |
|
|
void NewGraphInner(const py::object &cell, const py::args &args); |
|
|
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g); |
|
|
|
|
|
|
|
|
void MakeNewTopGraph(const string &cell_id, const py::args &args); |
|
|
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); |
|
|
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); |
|
|
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, |
|
|
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out, |
|
|
const std::string &out_id, const py::args &args); |
|
|
const std::string &out_id, const py::args &args); |
|
|
@@ -232,38 +258,44 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
const std::string &cell_id, const py::args &args); |
|
|
const std::string &cell_id, const py::args &args); |
|
|
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, |
|
|
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args, |
|
|
py::object *sens = nullptr); |
|
|
py::object *sens = nullptr); |
|
|
|
|
|
void ClearDynamicTopRes(const std::string &cell_id); |
|
|
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, |
|
|
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, |
|
|
const py::args &args); |
|
|
const py::args &args); |
|
|
std::string GetCellId(const py::object &obj, const py::args &args); |
|
|
std::string GetCellId(const py::object &obj, const py::args &args); |
|
|
std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); |
|
|
|
|
|
|
|
|
std::string GetTensorCellId(const std::string &cell_id); |
|
|
|
|
|
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens); |
|
|
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); |
|
|
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size); |
|
|
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights, |
|
|
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights, |
|
|
size_t arg_size, const std::string &cell_id); |
|
|
size_t arg_size, const std::string &cell_id); |
|
|
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); |
|
|
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder); |
|
|
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); |
|
|
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder); |
|
|
void UpdateGraphInfoMap(const std::string &cell_id); |
|
|
|
|
|
|
|
|
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id); |
|
|
|
|
|
void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph, |
|
|
|
|
|
const std::string &cell_id); |
|
|
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); |
|
|
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id); |
|
|
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, |
|
|
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource, |
|
|
const py::object &out, bool has_sens); |
|
|
const py::object &out, bool has_sens); |
|
|
void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs); |
|
|
|
|
|
|
|
|
void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs); |
|
|
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); |
|
|
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id); |
|
|
|
|
|
|
|
|
// Hold graph(forward and grad) info |
|
|
// Hold graph(forward and grad) info |
|
|
|
|
|
std::string GetCellOpInfo(); |
|
|
|
|
|
void ReplaceCellOpInfoByCellId(const std::string &cell_id); |
|
|
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) { |
|
|
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) { |
|
|
graph_info_map_[g].objects.push_back(obj); |
|
|
|
|
|
|
|
|
graph_info_map_[g]->objects.push_back(obj); |
|
|
} |
|
|
} |
|
|
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, |
|
|
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, |
|
|
bool is_param = false); |
|
|
bool is_param = false); |
|
|
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) { |
|
|
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) { |
|
|
graph_info_map_[g].params[id] = param; |
|
|
|
|
|
|
|
|
graph_info_map_[g]->params[id] = param; |
|
|
} |
|
|
} |
|
|
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, |
|
|
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, |
|
|
int64_t index = -1) { |
|
|
int64_t index = -1) { |
|
|
graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index}); |
|
|
|
|
|
|
|
|
graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index}); |
|
|
} |
|
|
} |
|
|
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, |
|
|
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node, |
|
|
const std::vector<int64_t> &index) { |
|
|
const std::vector<int64_t> &index) { |
|
|
graph_info_map_[g].node_map[id] = std::make_pair(node, index); |
|
|
|
|
|
|
|
|
graph_info_map_[g]->node_map[id] = std::make_pair(node, index); |
|
|
} |
|
|
} |
|
|
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, |
|
|
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node, |
|
|
const std::vector<int64_t> &index_sequence, bool is_param = false); |
|
|
const std::vector<int64_t> &index_sequence, bool is_param = false); |
|
|
@@ -274,7 +306,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
size_t grad_order_{0}; |
|
|
size_t grad_order_{0}; |
|
|
std::string top_cell_id_; |
|
|
std::string top_cell_id_; |
|
|
bool grad_flag_{false}; |
|
|
bool grad_flag_{false}; |
|
|
bool dynamic_cell_{false}; |
|
|
|
|
|
|
|
|
bool has_dynamic_cell_{false}; |
|
|
bool grad_is_running_{false}; |
|
|
bool grad_is_running_{false}; |
|
|
bool need_replace_forward_{true}; |
|
|
bool need_replace_forward_{true}; |
|
|
// The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script, |
|
|
// The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script, |
|
|
@@ -288,16 +320,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { |
|
|
FuncGraphPtr curr_g_{nullptr}; |
|
|
FuncGraphPtr curr_g_{nullptr}; |
|
|
// Records forwrad graph, the bottom is top graph |
|
|
// Records forwrad graph, the bottom is top graph |
|
|
std::stack<FuncGraphPtr> graph_stack_; |
|
|
std::stack<FuncGraphPtr> graph_stack_; |
|
|
|
|
|
// Records op info of every cell, the bottom is op info of top cell |
|
|
|
|
|
std::stack<std::string> cell_op_info_stack_; |
|
|
|
|
|
|
|
|
// Use vector for keep order |
|
|
// Use vector for keep order |
|
|
std::vector<CellInfo> cell_graph_list_; |
|
|
|
|
|
std::vector<TopCellInfo> top_cell_list_; |
|
|
|
|
|
|
|
|
std::vector<CellInfoPtr> cell_graph_list_; |
|
|
|
|
|
std::vector<TopCellInfoPtr> top_cell_list_; |
|
|
std::unordered_set<std::string> cell_input_args_; |
|
|
std::unordered_set<std::string> cell_input_args_; |
|
|
std::unordered_map<std::string, bool> cell_dynamic_map_; |
|
|
|
|
|
// Record all info for all cells |
|
|
// Record all info for all cells |
|
|
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_; |
|
|
|
|
|
// key: cell_id, value: (send_id, weighs_id), cache for sens and weight change |
|
|
|
|
|
std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_; |
|
|
|
|
|
|
|
|
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_; |
|
|
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_; |
|
|
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_; |
|
|
|
|
|
|
|
|
// Used for runop and replace forward result of grad graph |
|
|
// Used for runop and replace forward result of grad graph |
|
|
|