@@ -19,9 +19,6 @@ using namespace cuda;
 using namespace conv_bias;
 
 bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) const {
-    if (args.z_layout->ndim > 0)
-        return false;
-
     if (args.filter_meta.format != Param::Format::NCHW &&
         args.filter_meta.format != Param::Format::NHWC) {
         if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
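Removing this early return is what allows the cuDNN convolution algorithm to be selected when a z (residual) input is present; the actual z handling is added in the two hunks below. As a rough scalar model of the fused computation the surrounding code performs (illustrative only; the nonlinearity is whatever args.nonlinear_mode selects, ReLU is used here just for concreteness):

// Toy scalar model of the fused ConvBias computation, where conv_result
// stands in for one output element of the real convolution:
//     dst = nonlinearity(conv(src, filter) + z + bias)
// Before this change, a non-empty z made is_available() return false, so the
// cuDNN conv algorithm was never considered for such fusions.
#include <algorithm>

float fused_conv_bias_scalar(float conv_result, float z, float bias) {
    float acc = conv_result + z;   // the z add now emulated in exec()
    acc += bias;                   // bias ...
    return std::max(acc, 0.0f);    // ... and nonlinearity (ReLU chosen here)
}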
@@ -75,6 +72,15 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
         sizes.push_back(dst_layout.span().dist_byte());
     }
 
+    if (args.z_layout->ndim > 0 &&
+        args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
+        auto z_layout = *args.z_layout;
+        z_layout.dtype = DType();
+        args.opr->check_or_deduce_dtype_fwd(
+                args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
+        sizes.push_back(z_layout.span().dist_byte());
+    }
+
     SizeArgs conv_args = args;
     conv_args.dst_layout = &dst_layout;
 
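The new branch reserves an extra workspace slot whenever z is present but stored in a different dtype than the bias: the copied layout's dtype is cleared so check_or_deduce_dtype_fwd() fills in the compute dtype, and the slot is sized for z in that dtype. The exec() hunk below retrieves this buffer via bundle.get(2). As a standalone toy of the slot/offset bookkeeping such a bundle performs (ToyBundle and make_toy_bundle are illustrative names, not the MegEngine WorkspaceBundle API; real code would also align each slot):

// Each pushed size reserves one consecutive sub-buffer of a single workspace
// allocation, and get(i) returns a pointer to the i-th sub-buffer.
#include <cstddef>
#include <vector>

struct ToyBundle {
    char* base;
    std::vector<std::size_t> offsets;  // byte offset of each slot
    void* get(std::size_t i) const { return base + offsets[i]; }
};

ToyBundle make_toy_bundle(void* base, const std::vector<std::size_t>& sizes) {
    ToyBundle b{static_cast<char*>(base), {}};
    std::size_t off = 0;
    for (std::size_t s : sizes) {
        b.offsets.push_back(off);  // slot i starts where slot i-1 ended
        off += s;
    }
    return b;
}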
@@ -129,6 +135,22 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const {
                 cudnnGetErrorString(status), conv_args.to_string().c_str());
     }
 
+    if (args.z_layout->ndim > 0) {
+        auto z_tensor = *args.z_tensor;
+        if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
+            z_tensor.raw_ptr = bundle.get(2);
+            z_tensor.layout.dtype = DType();
+            args.opr->check_or_deduce_dtype_fwd(
+                    args.src_layout->dtype, args.filter_layout->dtype,
+                    z_tensor.layout.dtype);
+            auto typecvt = args.handle->create_operator<TypeCvt>();
+            typecvt->exec(*args.z_tensor, z_tensor);
+        }
+        auto add = args.handle->create_operator<ElemwiseForward>();
+        add->param().mode = Elemwise::Param::Mode::ADD;
+        add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
+    }
+
     handle_bias_and_nonlinear(
             args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
             args.bias_tensor);
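The added block emulates the z (residual) input on top of the plain cuDNN convolution: when z's dtype differs from the bias dtype, z is first converted into the reserved workspace buffer with a TypeCvt operator, then accumulated into the convolution output with an elementwise ADD, after which handle_bias_and_nonlinear() applies the bias and activation as before. A toy scalar version of those two steps (plain C++; the int8 input and dequantization scale are illustrative assumptions, not taken from the diff):

// Convert z to the compute dtype, then add it elementwise into the
// convolution output, mirroring the TypeCvt + Elemwise ADD pair above.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> add_residual(std::vector<float> conv_dst,
                                const std::vector<int8_t>& z, float z_scale) {
    for (std::size_t i = 0; i < conv_dst.size(); ++i) {
        float z_f = z_scale * static_cast<float>(z[i]);  // "TypeCvt" step
        conv_dst[i] += z_f;                              // "Elemwise ADD" step
    }
    return conv_dst;
}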