GitOrigin-RevId: 45a6c6ad51
@@ -307,7 +307,7 @@ VarNode& VarNode::shape(const TensorShape &shape) {
     return *this;
 }
 
-VarNode& VarNode::shape_alloc(const TensorShape &shape) {
+VarNode& VarNode::shape_alloc(const TensorShape &shape, size_t size_req) {
     mgb_assert(shape.ndim, "got empty shape in shape_alloc: "
             "var=%s owner_opr=%s{%s}", cname(), owner_opr()->cname(),
             owner_opr()->dyn_typeinfo()->name);
@@ -316,7 +316,7 @@ VarNode& VarNode::shape_alloc(const TensorShape &shape) {
             " NO_SYS_MEM_ALLOC flag; actual var: %s",
             cg::dump_var_info({this}).c_str());
     ComputingGraphImpl::downcast(owner_graph())
-            ->var_node_mem_manager().var_alloc_with_shape(this, shape);
+            ->var_node_mem_manager().var_alloc_with_shape(this, shape, size_req);
     return *this;
 }
@@ -1239,13 +1239,18 @@ void VarNodeMemManager::make_dev_tensor_from_mem_plan_single(
 }
 
 void VarNodeMemManager::var_alloc_with_shape(VarNode* var,
-                                             const TensorShape& shape) {
+                                             const TensorShape& shape,
+                                             size_t size_req) {
     mgb_assert(var->format().is_default(),
                "dynamic shape is currently only supported for var with "
                "default format; got %s",
                var->format().to_string().c_str());
     var->shape(shape);
-    auto size_req = var->dtype().size(shape.total_nr_elems());
+    if (size_req != 0) {
+        mgb_assert(var->dtype().size(shape.total_nr_elems()) <= size_req);
+    } else {
+        size_req = var->dtype().size(shape.total_nr_elems());
+    }
 
     auto&& mplan = var->m_mem_plan;
     if (!mplan.valid() || mplan.chunk().owner_var != var)
@@ -294,7 +294,13 @@ class VarNodeMemManager {
         void add_layout_constraint_level(
                 VarNode *dest, LayoutConstraintLevel level);
 
-        void var_alloc_with_shape(VarNode *var, const TensorShape &shape);
+        /**
+         * \brief allocate var memory with the given shape
+         *
+         * Allocate size_req bytes instead if size_req != 0.
+         */
+        void var_alloc_with_shape(VarNode* var, const TensorShape& shape,
+                                  size_t size_req = 0);
 
         /*!
          * \brief initialize mem plan for a single var
@@ -462,8 +462,10 @@ class VarNode final: public GraphNodeBase {
          * this var must have NO_SYS_MEM_ALLOC flag; if shape does not increase
          * and original tensor storage is valid, it is guaranteed that old data
          * would be retained.
+         *
+         * \warning if size_req != 0, size_req bytes would be allocated
+         * instead of the size implied by shape
          */
-        VarNode& shape_alloc(const TensorShape &shape);
+        VarNode& shape_alloc(const TensorShape &shape, size_t size_req = 0);
 
         /*!
          * \brief directly reset device tensor from another var
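
Usage sketch: the snippet below illustrates the intended semantics of the new size_req parameter. The helper name grow_alloc_example, the namespace qualification, the concrete shape, and the doubling factor are hypothetical; shape_alloc(), the NO_SYS_MEM_ALLOC requirement, and the fits-within-size_req assertion follow the diff above.

    // Sketch: pre-reserving storage with the new size_req parameter.
    // `var` is assumed to already have the NO_SYS_MEM_ALLOC flag set,
    // as shape_alloc() asserts.
    void grow_alloc_example(mgb::cg::VarNode* var) {
        mgb::TensorShape shape{16, 32};

        // size_req == 0 (the default) keeps the old behavior: exactly
        // var->dtype().size(shape.total_nr_elems()) bytes are allocated.
        var->shape_alloc(shape);

        // A nonzero size_req allocates size_req bytes instead;
        // var_alloc_with_shape() asserts that the shape's byte size does
        // not exceed size_req. The extra headroom is presumably meant for
        // a later shape_alloc() to a larger shape on the same var.
        size_t size_req = 2 * var->dtype().size(shape.total_nr_elems());
        var->shape_alloc(shape, size_req);
    }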