From: @huaweib Reviewed-by: @kisnwang,@jjfeing Signed-off-by: @jjfeingtags/v1.1.0
| @@ -38,16 +38,15 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa | |||||
| for (size_t i = 0; i < weight_height.size(); ++i) { | for (size_t i = 0; i < weight_height.size(); ++i) { | ||||
| auto wh = weight_height[i]; | auto wh = weight_height[i]; | ||||
| int re = wh % stride; | int re = wh % stride; | ||||
| int pad_along; | |||||
| if (re == 0) { | if (re == 0) { | ||||
| re = stride; | |||||
| } | |||||
| int pad = kernel_size[i] - re; | |||||
| padding_l->emplace_back(pad / 2); | |||||
| if (pad % 2 == 0) { | |||||
| padding_r->emplace_back(pad / 2); | |||||
| pad_along = std::max(SizeToInt(kernel_size[i]) - stride, 0); | |||||
| } else { | } else { | ||||
| padding_r->emplace_back(pad / 2 + 1); | |||||
| pad_along = std::max(SizeToInt(kernel_size[i]) - re, 0); | |||||
| } | } | ||||
| int pad = pad_along / 2; | |||||
| padding_l->emplace_back(pad); | |||||
| padding_r->emplace_back(pad_along - pad); | |||||
| } | } | ||||
| } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { | } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { | ||||
| MS_LOG(INFO) << "pad valid"; | MS_LOG(INFO) << "pad valid"; | ||||
| @@ -257,7 +257,8 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker | |||||
| MS_EXCEPTION_IF_NULL(address); | MS_EXCEPTION_IF_NULL(address); | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| if (tensor_address != nullptr && tensor_address != address && | if (tensor_address != nullptr && tensor_address != address && | ||||
| std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU) { | |||||
| (std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address)->DeviceType() != DeviceAddressType::kCPU || | |||||
| AnfAlgo::IsParameterWeight(item->cast<ParameterPtr>()))) { | |||||
| tensor->data_sync(false); | tensor->data_sync(false); | ||||
| } | } | ||||
| if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { | ||||
| @@ -21,7 +21,7 @@ Examples: | |||||
| >>> import mindspore.ops as ops | >>> import mindspore.ops as ops | ||||
| """ | """ | ||||
| from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register | |||||
| from .primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register | |||||
| from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry | ||||
| from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType | from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType | ||||
| from .primitive import constexpr | from .primitive import constexpr | ||||
| @@ -32,7 +32,7 @@ from .operations import * | |||||
| from .functional import * | from .functional import * | ||||
| __primitive__ = [ | __primitive__ = [ | ||||
| "prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature" | |||||
| "prim_attr_register", "Primitive", "PrimitiveWithInfer", "PrimitiveWithCheck", "signature" | |||||
| ] | ] | ||||
| __all__ = ["get_vm_impl_fn", "vm_impl_registry", | __all__ = ["get_vm_impl_fn", "vm_impl_registry", | ||||