GitOrigin-RevId: 5b036c2c5a
tags/v1.9.0
@@ -111,7 +111,6 @@ def test_xornet_trace_dump():
_, loss = val_fun(data, label)
loss = loss.numpy()
val_loss.append((step, loss))
print("Step: {} loss={}".format(step, loss))
opt.step()
test_data = np.array(
@@ -89,8 +89,7 @@ def test_subgraph(device, batch_size, channels, use_trace, symbolic, gopt_level,
return megengine.tensor(np.random.random(shape), dtype=dtype, device=device)
# skip this test because we could not run several reduces sequentially with the opr cache
if device == "cpux":
return
return
# test shape change
for image_shape in [(223, 223), (10, 20)]:
@@ -718,7 +718,6 @@ def test_assert_equal():
inp2 = g.make_h2d(dtype=np.float32, device="xpux")
op = builtin.AssertEqual(maxerr=1e-5)
out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
print(out)
g.compile(out)
file = io.BytesIO()
out_model = G.dump_graph([out])
@@ -51,7 +51,6 @@ def test_profiler(format, trace_mode):
with Profiler(profile_prefix, format=format):
infer()
print(profile_path)
assert os.path.exists(profile_path), "profiling results not found"
if format == "chrome_timeline.json":
@@ -49,6 +49,7 @@ struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
bool validated = false;
template <typename TFunctor>
void get_props(TFunctor&& functor) const {
@@ -280,7 +280,8 @@ void ChannelImpl::dispatch_default_cpu(
input_tensors.push_back(Tensor::make(
input_tensornd, HostTensorND::make_proxy(input_tensornd)));
}
auto output_tensors = OpDef::apply_on_physical_tensor(*op, input_tensors);
auto output_tensors = OpDef::apply_on_physical_tensor(
*op, input_tensors, output_descs, validated);
for (size_t i = 0; i < output_tensors.size(); ++i) {
output_tensornds[i].copy_from_fixlayout(output_tensors[i]->dev_tensor());
}
@@ -324,6 +325,7 @@ void ChannelImpl::dispatch_kernel(
MGB_RECORD_EVENT(ShapeInferEvent, validated);
ApplyOp cmd{Profiler::next_id(), std::move(op)};
cmd.validated = validated;
cmd.inputs = std::move(input_infos);
for (int i = 0; i < output_descs.size(); ++i) {
auto&& desc = output_descs[i];
@@ -703,14 +705,16 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
auto_evict(0);
}
auto apply_on_physical_tensor =
[&](auto&& self, const OpDef& def,
SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> {
[&](auto&& self, const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs,
const bool& validated) -> SmallVector<TensorPtr> {
auto apply_functor = [&](std::shared_ptr<OpDef> op,
SmallVector<TensorPtr> inputs,
size_t nr_outputs) -> SmallVector<TensorPtr> {
auto opname = op->trait()->make_name(*op);
imperative_log_profile_begin(opname.c_str());
auto outputs = self(self, *op, inputs);
// do not use the inferred output_desc in subgraphs
auto outputs = self(self, *op, inputs, output_descs, false);
imperative_log_profile_end(opname.c_str());
return outputs;
};
@@ -726,7 +730,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
inputs, apply_functor, const_functor);
return outputs;
}
return OpDef::apply_on_physical_tensor(def, inputs);
return OpDef::apply_on_physical_tensor(def, inputs, output_descs, validated);
};
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
// Begin profiling operator
@@ -757,8 +761,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
Timer::record_device(device));
}
// Apply op
SmallVector<LogicalTensorDesc> output_descs;
for (auto i : cmd.outputs) {
output_descs.push_back(i->desc);
}
// Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor(apply_on_physical_tensor, *cmd.op, inputs);
auto outputs = apply_on_physical_tensor(
apply_on_physical_tensor, *cmd.op, inputs, output_descs, cmd.validated);
// After execute
for (auto&& [device, kernel_id] : kernels) {
MGB_RECORD_EVENT_IF(
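
The helper above is a recursive generic lambda: it receives itself as `self`, so ops expanded from a subgraph can re-enter it, and those nested calls deliberately pass `validated = false` because the output_descs inferred for the outer op do not describe the inner ops' outputs. A minimal standalone sketch of the self-passing pattern (hypothetical example, not MegEngine code):

    #include <cstdio>

    int main() {
        // `self` stands in for the lambda itself, enabling recursion
        // without std::function; any generic lambda works this way.
        auto factorial = [](auto&& self, int n) -> int {
            return n <= 1 ? 1 : n * self(self, n - 1);
        };
        std::printf("%d\n", factorial(factorial, 5));  // prints 120
        return 0;
    }
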
@@ -39,8 +39,10 @@ DispatchMode OpDef::decide_dispatch_mode(
}
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return def.trait()->apply_on_physical_tensor(
def, std::move(inputs), output_descs, validated);
}
void OpDef::apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
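
Every call site of OpDef::apply_on_physical_tensor now supplies the output descriptors plus a `validated` flag obtained from shape inference. A minimal sketch of the calling convention, mirroring the test changes later in this diff (`op` and `inputs` are assumed to already exist):

    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& i : inputs) {
        input_descs.push_back({i->layout(), i->comp_node()});
    }
    // validated == true only when every output layout was fully inferred,
    // which lets the op take its fast path and skip recomputing shapes
    auto [output_descs, validated] =
            OpDef::infer_output_attrs_fallible(*op, input_descs);
    auto outputs =
            OpDef::apply_on_physical_tensor(*op, inputs, output_descs, validated);
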
@@ -51,7 +51,6 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape)
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
@@ -82,11 +81,16 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto& input = inputs[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
if (validated) {
target_shape = output_descs[0].layout;
} else {
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
}
TensorPtr output = Tensor::make(
TensorLayout(target_shape, input->dtype()), input->comp_node());
if (output->layout().is_empty()) {
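
The `validated` branch spares Broadcast a read of the shape tensor's contents, which would otherwise be copied to the default CPU just to recover a shape that inference already produced. A self-contained analogue of the fast/slow path split (hypothetical types, not the MegEngine API):

    #include <cstddef>
    #include <vector>

    // Hypothetical analogue: prefer the pre-validated descriptor; fall
    // back to materializing the target shape from a runtime value tensor.
    std::vector<size_t> target_shape(
            bool validated, const std::vector<size_t>& desc_shape,
            const std::vector<int>& shape_value) {
        if (validated) {
            return desc_shape;  // fast path: never touches the value tensor
        }
        std::vector<size_t> shape(shape_value.begin(), shape_value.end());
        return shape;
    }
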
@@ -171,7 +175,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
@@ -179,6 +184,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
if (validated) {
return {Tensor::make(src->blob(), 0, output_descs[0].layout)};
}
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
@@ -186,9 +195,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {Tensor::make(src->blob(), 0, tlayout)};
return {Tensor::make(src->blob(), 0, slayout.reshape(tshp))};
}
OP_TRAIT_REG(Reshape, Reshape)
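
In both branches Reshape stays a pure memory forward: the output aliases the input's blob at offset 0 under a new layout, so no data moves; the `validated` path merely skips recomputing that layout. A self-contained analogue of such an aliasing view (hypothetical types):

    #include <cstddef>
    #include <memory>
    #include <vector>

    // Hypothetical analogue of Tensor::make(blob, offset, layout): reshape
    // by sharing storage and swapping only the layout metadata.
    struct Blob { std::shared_ptr<float> data; size_t size; };
    struct View { Blob blob; size_t offset; std::vector<size_t> shape; };

    View reshape(const View& src, std::vector<size_t> new_shape) {
        // no allocation, no copy: the result aliases src's storage
        return View{src.blob, src.offset, std::move(new_shape)};
    }
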
@@ -33,9 +33,8 @@ cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& in
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& opr = def.cast_final_safe<CondTake>();
mgb_assert(opr.same_type<CondTake>());
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(inputs.size() == 2, "CondTake takes 2 inputs; got %lu", inputs.size());
auto&& inp = inputs[0];
@@ -196,16 +196,14 @@ void apply_on_device_tensornd(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op = static_cast<const CustomOpDef&>(def);
auto [output_descs, success] = op.infer_output_attrs(inputs);
mgb_assert(success == true, "infer output attributes fall\n");
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated == true, "inferring output attributes failed\n");
SmallVector<TensorPtr> outputs(output_descs.size());
for (size_t i = 0; i < outputs.size(); ++i) {
auto& output = outputs[i];
auto& output_desc = output_descs[i];
output = Tensor::make(output_desc.layout, output_desc.comp_node);
output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
@@ -112,17 +112,14 @@ void apply_on_device_tensornd(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
TensorShapeArray inp_shapes(inputs.size());
for (unsigned i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor();
inp_shapes[i] = inputs[i]->layout();
}
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
DeviceTensorND out = BlobManager::inst()->alloc_workspace_with_defrag(
inp_tensornds[0].comp_node(), {shape, inp_tensornds[0].layout().dtype});
inp_tensornds[0].comp_node(), output_descs[0].layout);
SmallVector<DeviceTensorND> oup_tensornds = {out};
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
return {Tensor::make(oup_tensornds[0])};
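
Elemwise used to recompute the broadcast output shape from all input shapes on every apply; it now allocates directly from the pre-inferred output_descs[0].layout. For reference, a minimal sketch of the right-aligned broadcast rule such inference follows (hypothetical helper, not the MegEngine implementation):

    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Hypothetical helper: NumPy-style broadcast of two shapes, aligned
    // from the trailing dimension; a size of 1 broadcasts against anything.
    std::vector<size_t> broadcast_shape(
            std::vector<size_t> a, std::vector<size_t> b) {
        if (a.size() < b.size()) std::swap(a, b);
        std::vector<size_t> out(a);
        size_t off = a.size() - b.size();
        for (size_t i = 0; i < b.size(); ++i) {
            size_t x = a[off + i], y = b[i];
            if (x == y || y == 1) out[off + i] = x;
            else if (x == 1)      out[off + i] = y;
            else throw std::runtime_error("shapes are not broadcastable");
        }
        return out;
    }
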
@@ -221,7 +218,8 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
}
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(
inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
@@ -24,7 +24,8 @@ SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
size_t size = inputs.size();
auto&& op = def.cast_final_safe<CheckNonFinite>();
SmallVector<TensorPtr> outputs(size + 1);
@@ -63,18 +64,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return {dests, true};
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
size_t size = inputs.size();
SmallVector<LogicalTensorDesc> dests(size + 1);
for (size_t i = 0; i < size; ++i) {
dests[i].comp_node = inputs[i]->comp_node();
dests[i].layout = inputs[i]->layout();
}
dests[size].comp_node = inputs[0]->comp_node();
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return dests;
}
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
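
The deleted infer_output_attrs duplicated the descriptor construction in the infer_output_attrs_fallible just above it; with descriptors now flowing in through apply_on_physical_tensor, only the fallible variant is needed. As the retained code shows, CheckNonFinite emits one descriptor per input plus a trailing scalar Int32 flag; a self-contained analogue of that output-arity convention (hypothetical, mirroring only the shapes, not the op's actual arithmetic):

    #include <cmath>
    #include <vector>

    // Hypothetical analogue: N input buffers yield N same-shaped outputs
    // plus one trailing scalar flag marking any non-finite value.
    std::vector<std::vector<float>> check_non_finite(
            const std::vector<std::vector<float>>& inputs) {
        std::vector<std::vector<float>> outputs(inputs);  // N passthroughs
        float flag = 0.f;
        for (auto&& buf : inputs)
            for (float v : buf)
                if (!std::isfinite(v)) flag = 1.f;
        outputs.push_back({flag});  // trailing scalar, Int32 in the real op
        return outputs;
    }
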
@@ -51,11 +51,13 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
if (memory_forward_success(def, inputs)) {
return {Tensor::make(inputs[0]->blob(), 0, inputs[0]->layout())};
}
return proxy_graph_detail::apply_on_physical_tensor(def, inputs);
return proxy_graph_detail::apply_on_physical_tensor(
def, inputs, output_descs, validated);
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -419,8 +419,7 @@ _INST_RNG_MAKER(2)
template <typename Op>
void exec(
const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
const SmallVector<TensorPtr>& outputs) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
@@ -451,82 +450,68 @@ void exec(
}
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
SmallVector<CompNode> infer_output_cns(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
LogicalTensorDesc dest;
CompNode cn;
auto&& rng = op.cast_final_safe<Op>();
auto handle = rng.handle;
if (handle) {
dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
cn = RNGDnnOpManager::get_comp_node(handle);
} else {
dest.comp_node = inputs[0]->comp_node();
cn = inputs[0]->comp_node();
}
constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
if (!rng_with_shape) {
for (int i = 0; i < inputs.size(); ++i) {
mgb_assert(
inputs[i]->comp_node() == dest.comp_node,
inputs[i]->comp_node() == cn,
"%s expects the device of inputs[%d] to be same as the device of "
"handle; "
"got %s and %s actually",
rng.dyn_typeinfo()->name, i,
inputs[i]->comp_node().to_string().c_str(),
dest.comp_node.to_string().c_str());
inputs[i]->comp_node().to_string().c_str(), cn.to_string().c_str());
}
}
dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
return {dest};
return {cn};
}
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
SmallVector<CompNode> infer_output_cns<ShuffleRNG>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
SmallVector<CompNode> cns(2);
auto&& rng = op.cast_final_safe<ShuffleRNG>();
auto handle = rng.handle;
if (handle) {
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
cns[0] = RNGDnnOpManager::get_comp_node(handle);
cns[1] = RNGDnnOpManager::get_comp_node(handle);
} else {
dests[0].comp_node = inputs[0]->comp_node();
dests[1].comp_node = inputs[0]->comp_node();
cns[0] = inputs[0]->comp_node();
cns[1] = inputs[0]->comp_node();
}
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
dests[1].layout =
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
return dests;
return cns;
}
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
SmallVector<CompNode> infer_output_cns<Dropout>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
SmallVector<CompNode> cns(2);
auto&& cn = inputs[0]->comp_node();
dests[0].comp_node = cn;
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
auto get_mask_size = [&]() -> size_t {
auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
return dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
inputs[0]->layout());
};
dests[1].comp_node = cn;
dests[1].layout = TensorLayout(TensorShape({get_mask_size()}), dtype::Byte());
return dests;
cns[0] = cn;
cns[1] = cn;
return cns;
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
SmallVector<CompNode> cns = infer_output_cns<Op>(def, inputs);
for (size_t i = 0; i < cns.size(); i++) {
outputs.push_back(Tensor::make(output_descs[i].layout, cns[i]));
}
exec<Op>(def, inputs, outputs, {});
exec<Op>(def, inputs, outputs);
return outputs;
}
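
With layouts now supplied by the caller, the RNG ops' old infer_output_attrs collapses into infer_output_cns: only the output device still has to be resolved, from the RNG handle when one exists, otherwise from the first input. A tiny analogue of that resolution rule (hypothetical types):

    // Hypothetical analogue: an RNG handle pins its device; outputs are
    // allocated there, falling back to the first input's device otherwise.
    struct Handle { int device_id; };

    int output_device(const Handle* handle, int first_input_device) {
        return handle ? handle->device_id : first_input_device;
    }
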
@@ -99,7 +99,8 @@ HostTensorND get_var_shape_host_tensor(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return {Tensor::make(std::move(get_var_shape_host_tensor(def, inputs)))};
}
@@ -180,7 +181,8 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
}
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(
inputs.size() == 1, "ParamPackSplit takes 1 input; got %lu", inputs.size());
@@ -217,7 +219,8 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
}
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have more than one input");
auto comp_node = inputs.front()->comp_node();
@@ -62,25 +62,10 @@ OP_TRAIT_REG(FastpathCopy, FastpathCopy)
namespace {
namespace shape_infer {
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto& op = def.cast_final_safe<ShapeInfer>();
size_t nr_inputs = inputs.size();
mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
SmallVector<LogicalTensorDesc> input_descs;
for (size_t i = 0; i < nr_inputs; ++i) {
auto input = inputs[i]->get_value();
TensorLayout layout;
layout.ndim = input.shape(0);
for (size_t i = 0; i < layout.ndim; ++i) {
layout[i] = input.ptr<int32_t>()[i];
}
layout.dtype = op.dtypes[i];
layout.init_contiguous_stride();
input_descs.push_back({layout, op.devices[i]});
}
auto [output_descs, valid] =
OpDef::infer_output_attrs_fallible(*op.op, input_descs);
mgb_assert(valid, "shape inference incomplete");
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated, "shape inference incomplete");
SmallVector<TensorPtr> outputs;
for (auto&& output_desc : output_descs) {
HostTensorND shape_tensor{
@@ -189,7 +174,9 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::Identity::make(inputs[0], config);
}
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
return SmallVector<TensorPtr>{inputs[0]};
}
OP_TRAIT_REG(Identity, Identity)
@@ -588,7 +575,9 @@ ComputingGraphHolder<Kind>& get_computing_graph(
return *cg_holder_queue.back();
}
auto apply_on_physical_tensor(const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
}
@@ -451,7 +451,14 @@ public:
}
} else {
if (dep.type == cg::static_infer::DepType::SHAPE) {
if (auto* val = infer(output_data[dep.idx].shape_infer, sync)) {
// use opr->output()->shape() when it is available,
// otherwise infer it
if (!owner.m_opr->output(dep.idx)->shape().is_empty()) {
target.inp_val.val[i].m_shape =
&owner.m_opr->output(dep.idx)->shape();
} else if (
auto* val =
infer(output_data[dep.idx].shape_infer, sync)) {
target.inp_val.val[i].m_shape = val;
} else
return false;
@@ -798,7 +805,8 @@ public:
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& desc, const bool& validated) {
auto raw_inputs = to_raw_ptr_array(inputs);
auto& minigraph = get_cached_minigraph(def, raw_inputs);
auto _ = scoped_attach(&minigraph);
@@ -811,10 +819,12 @@ public:
// LogicalTensorDesc for minigraph.opr()->usable_output()
SmallVector<LogicalTensorDesc> output_descs;
for (size_t i = 0; i < minigraph.opr()->output().size(); ++i) {
auto* var = minigraph.opr()->output()[i];
auto* shape = sess.infer(sess.output_data[i].shape_infer, true);
mgb_assert(shape);
minigraph.opr()->output()[i]->shape(*shape);
var->shape(*shape);
}
for (size_t i = 0; i < minigraph.output_size(); ++i) {
auto* ovar = minigraph.output_var(i);
mgb_assert(ovar->dtype().valid() && ovar->comp_node().valid());
@@ -829,6 +839,7 @@ public:
outputs[i] =
Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
auto raw_outputs = to_raw_ptr_array(outputs);
CompNode::UnorderedSet used_cns;
for (auto&& out : raw_outputs) {
@@ -843,6 +854,7 @@ public:
}
}
}
// some oprs (e.g. Subtensor) may invoke infer_value during execution,
// so we need to create the inference session here
minigraph.execute(raw_inputs, raw_outputs, m_env);
@@ -853,6 +865,7 @@ public:
}
}
}
return outputs;
}
};
@@ -27,9 +27,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
auto ret =
proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(def, inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto ret = proxy_graph::ProxyGraphTypeI::inst().apply_on_physical_tensor(
def, inputs, output_descs, validated);
return ret;
}
@@ -62,15 +62,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
}
auto subgraph = def.trait()->make_forward_graph(def, input_descs);
auto apply_functor = [](const std::shared_ptr<OpDef>& op,
const SmallVector<TensorPtr>& inputs, size_t nr_outputs) {
return OpDef::apply_on_physical_tensor(*op, inputs);
auto apply_functor = [&output_descs](
const std::shared_ptr<OpDef>& op,
const SmallVector<TensorPtr>& inputs,
size_t nr_outputs) {
// do not use the inferred output_desc in subgraphs
return OpDef::apply_on_physical_tensor(*op, inputs, output_descs, false);
};
auto const_functor = [&](const TensorPtr& value) { return value; };
auto outputs = subgraph.apply<TensorPtr>(inputs, apply_functor, const_functor);
@@ -77,7 +77,9 @@ void TensorSanityCheck::enable() {
std::move(trait.apply_on_physical_tensor));
trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
[this, backup = backup.get()](
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs,
const bool& validated) {
for (auto&& i : inputs) {
if (!m_checker->check(i)) {
mgb_throw(
@@ -86,7 +88,7 @@ void TensorSanityCheck::enable() {
print_op(def).c_str());
}
}
auto output = (*backup)(def, inputs);
auto output = (*backup)(def, inputs, output_descs, validated);
for (auto&& i : output) {
mgb_assert(m_checker->check(i));
}
@@ -51,7 +51,8 @@ public:
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
/*!
* \brief Call the corresponding dnn op to calculate results. Output
@@ -18,7 +18,8 @@ namespace imperative {
namespace proxy_graph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
@@ -18,7 +18,8 @@ namespace imperative {
namespace subgraph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated);
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
@@ -81,7 +81,13 @@ T prepare_optimized_backward_inputs(
SmallVector<TensorPtr> apply_shared_on_physical_tensor(
std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node()});
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*def, input_descs);
return OpDef::apply_on_physical_tensor(*def, inputs, output_descs, validated);
}
TEST(TestImperative, BackwardGraphBasic) {
@@ -106,7 +112,13 @@ TEST(TestImperative, BackwardGraphBasic) {
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
for (size_t i = 0; i < inputs.size(); i++) {
input_descs[i].value = inputs[i]->dev_tensor();
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*attr, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
inputs.push_back(outputs[0]);
hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back()));
@@ -161,7 +173,10 @@ TEST(TestImperative, BackwardGraphIdentity) {
auto&& save_for_backward = result.input_mask;
auto&& input_has_grad = result.output_mask;
auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*attr, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*attr, inputs, output_descs, validated);
inputs.push_back(outputs[0]);
inputs.push_back(dc);
mgb_assert(save_for_backward.size() == inputs.size());
@@ -238,7 +253,13 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto a_tn = Tensor::make(*a_hv);
auto b_tn = Tensor::make(*b_hv);
auto dc_tn = Tensor::make(*dc_hv);
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];
SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back({a_tn->layout(), a_tn->comp_node(), a_tn->dev_tensor()});
input_descs.push_back({b_tn->layout(), b_tn->comp_node(), b_tn->dev_tensor()});
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto c_tn = OpDef::apply_on_physical_tensor(
*op, {a_tn, b_tn}, output_descs, validated)[0];
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
@@ -35,7 +35,8 @@ TEST(TestImperative, AllReduceBasic) {
megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, "all_reduce", 2,
idx, idx == 0, false, server_addr, port, dtype::Float32(), "nccl", "");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
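
Call sites with no usable inference result, like this collective-communication test, pass an empty output_descs together with `validated = false`, and the op computes its own output layouts as before. The convention in brief (names follow the diff):

    SmallVector<LogicalTensorDesc> output_descs;  // intentionally left empty
    auto oup = OpDef::apply_on_physical_tensor(
            *def, {inp}, output_descs, /*validated=*/false);
    // with validated == false the op must not read entries of output_descs
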
@@ -135,7 +135,9 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
imp_physical_inp[i] = Tensor::make(host_inp[i]);
}
auto imp_oup = OpDef::apply_on_physical_tensor(*m_op, imp_physical_inp);
SmallVector<LogicalTensorDesc> output_descs;
auto imp_oup = OpDef::apply_on_physical_tensor(
*m_op, imp_physical_inp, output_descs, false);
mgb_assert(imp_oup.size() == nr_oups);
// check input not modified
@@ -122,7 +122,10 @@ void run_graph(size_t mem_reserved) {
Param param{Param::Mode::MUL};
attr.param.write_pod(param);
auto out = OpDef::apply_on_physical_tensor(*op, {ptr_a[1], ptr_a[99]}).at(0);
SmallVector<LogicalTensorDesc> output_descs;
auto out = OpDef::apply_on_physical_tensor(
*op, {ptr_a[1], ptr_a[99]}, output_descs, false)
.at(0);
// value before defrag
HostTensorND host_out_before;
@@ -36,7 +36,8 @@ TEST(TestImperative, IORemote) {
auto def = imperative::RemoteSend::make(
"io_remote_test", server_addr, port, 1, "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
};
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
@@ -44,7 +45,8 @@ TEST(TestImperative, IORemote) {
"io_remote_test", server_addr, port, 0, CompNode::load("gpu1"),
std::vector<int32_t>{(int32_t)vector_size}, dtype::Float32(), "nccl");
auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
SmallVector<LogicalTensorDesc> output_descs;
auto oup = OpDef::apply_on_physical_tensor(*def, {inp}, output_descs, false);
HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
@@ -25,7 +25,14 @@ void check_rng_basic(Args&&... args) {
DeviceTensorND tshape_dev;
cg::copy_shape_to_tensor_value(tshape_dev, tshape);
SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back(
{inputs[0]->layout(), inputs[0]->comp_node(),
inputs[0]->dev_tensor()});
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto outputs = OpDef::apply_on_physical_tensor(
*op, inputs, output_descs, validated);
ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
@@ -41,7 +48,14 @@ void check_rng_with_input_basic(
const CompNode& cn, const SmallVector<TensorPtr>& inputs, Args&&... args) {
Handle h = new_handle(cn, 123);
auto op = Op::make(std::forward<Args>(args)..., h);
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node(), i->dev_tensor()});
}
auto [output_descs, validated] =
OpDef::infer_output_attrs_fallible(*op, input_descs);
auto outputs =
OpDef::apply_on_physical_tensor(*op, inputs, output_descs, validated);
ASSERT_TRUE(outputs[0]->layout().eq_shape(inputs[0]->shape()));
ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
@@ -142,7 +142,8 @@ public:
const TensorLayout& layout() const { return m_layout; }
MemAllocPlan& layout(const TensorLayout& dest, bool allow_shape_change = false);
MGE_WIN_DECLSPEC_FUC MemAllocPlan& layout(
const TensorLayout& dest, bool allow_shape_change = false);
#if MGB_ENABLE_JSON
MGE_WIN_DECLSPEC_FUC std::shared_ptr<json::Value> to_json() const override;