From: @yonibaehr_admin Reviewed-by: @HilbertDavid,@ddwsky Signed-off-by: @HilbertDavidtags/v1.1.0
| @@ -20,58 +20,60 @@ | |||||
| #include "nnacl/fp32_grad/activation_grad.h" | #include "nnacl/fp32_grad/activation_grad.h" | ||||
| #include "nnacl/errorcode.h" | #include "nnacl/errorcode.h" | ||||
| inline int ReluGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0 ? 1.0f : 0.0f; | |||||
| inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| if (src1[i] > 0) { | |||||
| dst[i] = src0[i]; | |||||
| } else { | |||||
| dst[i] = 0; | |||||
| } | |||||
| } | } | ||||
| ElementMul(src0, dst, dst, length); | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int Relu6Grad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| if (src1[i] < 0) { | |||||
| dst[i] = 0; | |||||
| int Relu6Grad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| if (src1[i] > 0.0f && src1[i] <= 6.0f) { | |||||
| dst[i] = src0[i]; | |||||
| } else { | } else { | ||||
| dst[i] = src1[i] > 6.0f ? 0.0f : 1.0f; | |||||
| dst[i] = 0.0f; | |||||
| } | } | ||||
| } | } | ||||
| ElementMul(src0, dst, dst, length); | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| dst[i] = src1[i] > 0.0f ? 1.0f : alpha; | dst[i] = src1[i] > 0.0f ? 1.0f : alpha; | ||||
| } | } | ||||
| ElementMul(src0, dst, dst, length); | ElementMul(src0, dst, dst, length); | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int SigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| int SigmoidGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); | dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int TanhGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| int TanhGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; | dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int HSwishGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| int HSwishGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); | float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); | ||||
| dst[i] = tmp * src0[i]; | dst[i] = tmp * src0[i]; | ||||
| } | } | ||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| int HSigmoidGrad(float *src0, float *src1, int length, float *dst) { | |||||
| for (int i = 0; i < length; ++i) { | |||||
| int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst) { | |||||
| for (size_t i = 0; i < length; ++i) { | |||||
| float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); | float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); | ||||
| dst[i] = tmp * src0[i]; | dst[i] = tmp * src0[i]; | ||||
| } | } | ||||
| @@ -30,13 +30,13 @@ typedef struct ActivationGradParameter { | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| int ReluGrad(float *src0, float *src1, int length, float *dst); | |||||
| int Relu6Grad(float *src0, float *src1, int length, float *dst); | |||||
| int LReluGrad(float *src0, float *src1, int length, float *dst, float alpha); | |||||
| int SigmoidGrad(float *src0, float *src1, int length, float *dst); | |||||
| int TanhGrad(float *src0, float *src1, int length, float *dst); | |||||
| int HSwishGrad(float *src0, float *src1, int length, float *dst); | |||||
| int HSigmoidGrad(float *src0, float *src1, int length, float *dst); | |||||
| int ReluGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| int Relu6Grad(float *src0, float *src1, size_t length, float *dst); | |||||
| int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha); | |||||
| int SigmoidGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| int TanhGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| int HSwishGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -34,21 +34,21 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter | |||||
| memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); | memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); | ||||
| float kk = (float)(win_h * win_w); | float kk = (float)(win_h * win_w); | ||||
| for (uint16_t ib = 0; ib < output_batch; ib++) { | |||||
| for (int ib = 0; ib < output_batch; ib++) { | |||||
| float *out = &output_ptr[(ib * in_h * in_w * channel)]; | float *out = &output_ptr[(ib * in_h * in_w * channel)]; | ||||
| const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; | const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; | ||||
| // iterate over yt | // iterate over yt | ||||
| for (uint16_t yh = 0; yh < output_h; yh++) { | |||||
| for (uint16_t yw = 0; yw < output_w; yw++) { | |||||
| for (uint16_t ic = 0; ic < channel; ic++) { | |||||
| for (int yh = 0; yh < output_h; yh++) { | |||||
| for (int yw = 0; yw < output_w; yw++) { | |||||
| for (int ic = 0; ic < channel; ic++) { | |||||
| int idx = (yw + yh * output_w) * channel + ic; | int idx = (yw + yh * output_w) * channel + ic; | ||||
| float delta = inPtr[idx] / kk; | float delta = inPtr[idx] / kk; | ||||
| for (int32_t kh = 0; kh < win_h; kh++) { | |||||
| for (int kh = 0; kh < win_h; kh++) { | |||||
| int xh = yh * stride_h + kh - pad_h; | int xh = yh * stride_h + kh - pad_h; | ||||
| if ((xh < 0) || (xh >= in_h)) { | if ((xh < 0) || (xh >= in_h)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| for (int32_t kw = 0; kw < win_w; kw++) { | |||||
| for (int kw = 0; kw < win_w; kw++) { | |||||
| int xw = yw * stride_w + kw - pad_w; | int xw = yw * stride_w + kw - pad_w; | ||||
| if ((xw < 0) || (xw >= in_w)) { | if ((xw < 0) || (xw >= in_w)) { | ||||
| continue; | continue; | ||||
| @@ -78,25 +78,25 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy | |||||
| int output_batch = pooling_param->output_batch_; | int output_batch = pooling_param->output_batch_; | ||||
| memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); | memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float)); | ||||
| for (uint16_t ib = 0; ib < output_batch; ib++) { | |||||
| for (int ib = 0; ib < output_batch; ib++) { | |||||
| float *out = &output_ptr[(ib * in_h * in_w * channel)]; | float *out = &output_ptr[(ib * in_h * in_w * channel)]; | ||||
| const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); | const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]); | ||||
| const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); | const float *dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]); | ||||
| for (uint16_t yh = 0; yh < output_h; yh++) { | |||||
| for (uint16_t yw = 0; yw < output_w; yw++) { | |||||
| for (uint16_t ic = 0; ic < channel; ic++) { | |||||
| for (int yh = 0; yh < output_h; yh++) { | |||||
| for (int yw = 0; yw < output_w; yw++) { | |||||
| for (int ic = 0; ic < channel; ic++) { | |||||
| int idx = (yw + yh * output_w) * channel + ic; | int idx = (yw + yh * output_w) * channel + ic; | ||||
| float delta = dyPtr[idx]; | float delta = dyPtr[idx]; | ||||
| float max_val = -FLT_MAX; | float max_val = -FLT_MAX; | ||||
| int max_idx = 0; | int max_idx = 0; | ||||
| for (int32_t kh = 0; kh < win_h; kh++) { | |||||
| for (int kh = 0; kh < win_h; kh++) { | |||||
| int xh = yh * stride_h + kh - pad_h; | int xh = yh * stride_h + kh - pad_h; | ||||
| if ((xh < 0) || (xh >= in_h)) { | if ((xh < 0) || (xh >= in_h)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| for (int32_t kw = 0; kw < win_w; kw++) { | |||||
| for (int kw = 0; kw < win_w; kw++) { | |||||
| int xw = yw * stride_w + kw - pad_w; | int xw = yw * stride_w + kw - pad_w; | ||||
| if ((xw < 0) || (xw >= in_w)) { | if ((xw < 0) || (xw >= in_w)) { | ||||
| continue; | continue; | ||||
| @@ -160,7 +160,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT | |||||
| #endif | #endif | ||||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | ||||
| attr->kernelH = kernel_size.at(0); | attr->kernelH = kernel_size.at(0); | ||||
| attr->kernelW = kernel_size.at(1); | |||||
| attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0); | |||||
| auto stride = CastToInt(prim.GetAttr("stride")); | auto stride = CastToInt(prim.GetAttr("stride")); | ||||
| attr->strideH = stride.at(2); | attr->strideH = stride.at(2); | ||||
| @@ -240,7 +240,7 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive | |||||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | ||||
| attr->kernelH = kernel_size.at(0); | attr->kernelH = kernel_size.at(0); | ||||
| attr->kernelW = kernel_size.at(1); | |||||
| attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0); | |||||
| auto stride = CastToInt(prim.GetAttr("stride")); | auto stride = CastToInt(prim.GetAttr("stride")); | ||||
| attr->strideH = stride.at(2); | attr->strideH = stride.at(2); | ||||
| @@ -104,22 +104,22 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = CastToInt(prim.GetAttr("pad_list")); | auto pad_list = CastToInt(prim.GetAttr("pad_list")); | ||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| attr->padUp = pad_list.at(0); | |||||
| attr->padDown = pad_list.at(1); | |||||
| attr->padLeft = pad_list.at(2); | |||||
| attr->padRight = pad_list.at(3); | |||||
| auto dilation = CastToInt(prim.GetAttr("dilation")); | auto dilation = CastToInt(prim.GetAttr("dilation")); | ||||
| attr->dilateH = dilation[2]; | |||||
| attr->dilateW = dilation[3]; | |||||
| attr->dilateH = dilation.at(2); | |||||
| attr->dilateW = dilation.at(3); | |||||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | ||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| attr->kernelH = kernel_size.at(0); | |||||
| attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0); | |||||
| auto stride = CastToInt(prim.GetAttr("stride")); | auto stride = CastToInt(prim.GetAttr("stride")); | ||||
| attr->strideH = stride[0]; | |||||
| attr->strideW = stride[1]; | |||||
| attr->strideH = stride.at(0); | |||||
| attr->strideW = stride.at(1); | |||||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); | attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); | ||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode")); | ||||
| @@ -105,22 +105,22 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode | |||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| auto pad_list = CastToInt(prim.GetAttr("pad_list")); | auto pad_list = CastToInt(prim.GetAttr("pad_list")); | ||||
| attr->padUp = pad_list[0]; | |||||
| attr->padDown = pad_list[1]; | |||||
| attr->padLeft = pad_list[2]; | |||||
| attr->padRight = pad_list[3]; | |||||
| attr->padUp = pad_list.at(0); | |||||
| attr->padDown = pad_list.at(1); | |||||
| attr->padLeft = pad_list.at(2); | |||||
| attr->padRight = pad_list.at(3); | |||||
| auto dilation = CastToInt(prim.GetAttr("dilation")); | auto dilation = CastToInt(prim.GetAttr("dilation")); | ||||
| attr->dilateH = dilation[2]; | |||||
| attr->dilateW = dilation[3]; | |||||
| attr->dilateH = dilation.at(2); | |||||
| attr->dilateW = dilation.at(3); | |||||
| auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | auto kernel_size = CastToInt(prim.GetAttr("kernel_size")); | ||||
| attr->kernelH = kernel_size[0]; | |||||
| attr->kernelW = kernel_size[1]; | |||||
| attr->kernelH = kernel_size.at(0); | |||||
| attr->kernelW = (kernel_size.size() > 1) ? kernel_size.at(1) : kernel_size.at(0); | |||||
| auto stride = CastToInt(prim.GetAttr("stride")); | auto stride = CastToInt(prim.GetAttr("stride")); | ||||
| attr->strideH = stride[0]; | |||||
| attr->strideW = stride[1]; | |||||
| attr->strideH = stride.at(0); | |||||
| attr->strideW = stride.at(1); | |||||
| attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); | attr->channelOut = CastToInt(prim.GetAttr("out_channel")).front(); | ||||
| @@ -84,10 +84,15 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> | |||||
| } else { | } else { | ||||
| attr->format = schema::Format_NUM_OF_FORMAT; | attr->format = schema::Format_NUM_OF_FORMAT; | ||||
| } | } | ||||
| if (prim.instance_name() == "MaxPoolGrad") { | if (prim.instance_name() == "MaxPoolGrad") { | ||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | attr->poolingMode = schema::PoolMode_MAX_POOLING; | ||||
| } else if (prim.instance_name() == "MeanPoolGrad") { | |||||
| } else if (prim.instance_name() == "AvgPoolGrad") { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | |||||
| } else if (prim.instance_name() == "AvgPoolGradGpu") { | |||||
| attr->poolingMode = schema::PoolMode_MEAN_POOLING; | attr->poolingMode = schema::PoolMode_MEAN_POOLING; | ||||
| } else { | |||||
| attr->poolingMode = schema::PoolMode_MAX_POOLING; | |||||
| } | } | ||||
| auto pad_mode = GetValue<std::string>(prim.GetAttr("padding")); | auto pad_mode = GetValue<std::string>(prim.GetAttr("padding")); | ||||
| @@ -609,7 +609,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std: | |||||
| } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || | ||||
| op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { | ||||
| return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | return NewPrimitiveC<ActivationGrad>(prim, inputs, quantType); | ||||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "MeanPoolGrad") || (op_type == "AvgPoolGradGpu")) { | |||||
| } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu")) { | |||||
| return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | return NewPrimitiveC<PoolingGrad>(prim, inputs, quantType); | ||||
| } else if (op_type == "Conv2DBackpropFilter") { | } else if (op_type == "Conv2DBackpropFilter") { | ||||
| return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType); | ||||
| @@ -35,10 +35,6 @@ int PoolingGradCPUKernel::Init() { | |||||
| auto in_shape = in_tensors_.at(0)->shape(); | auto in_shape = in_tensors_.at(0)->shape(); | ||||
| auto out_shape = in_tensors_.at(1)->shape(); | auto out_shape = in_tensors_.at(1)->shape(); | ||||
| if (pool_param->pool_mode_ == PoolMode_AvgPool) { | |||||
| in_shape = in_tensors_.at(1)->shape(); | |||||
| out_shape = in_tensors_.at(0)->shape(); | |||||
| } | |||||
| int input_h = in_shape.at(1); | int input_h = in_shape.at(1); | ||||
| int input_w = in_shape.at(2); | int input_w = in_shape.at(2); | ||||
| @@ -71,6 +67,7 @@ int PoolingGradCPUKernel::Execute(int task_id) { | |||||
| auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | auto dy_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | ||||
| MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); | MaxPoolingGrad(input_ptr, dx_ptr, dy_ptr, output_ptr, pool_param, task_id); | ||||
| } else { | } else { | ||||
| input_ptr = reinterpret_cast<float *>(in_tensors_.at(2)->MutableData()); | |||||
| AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); | AvgPoolingGrad(input_ptr, output_ptr, pool_param, task_id); | ||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -88,15 +85,6 @@ int PoolingGradImpl(void *cdata, int task_id) { | |||||
| } | } | ||||
| int PoolingGradCPUKernel::Run() { | int PoolingGradCPUKernel::Run() { | ||||
| // clear output buffer before parallel run | |||||
| PoolingParameter *pooling_param = reinterpret_cast<PoolingParameter *>(op_parameter_); | |||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||||
| int size = | |||||
| pooling_param->input_w_ * pooling_param->input_h_ * pooling_param->input_channel_ * pooling_param->output_batch_; | |||||
| for (int i = 0; i < size; i++) { | |||||
| output_ptr[i] = 0.0; | |||||
| } | |||||
| int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); | int error_code = ParallelLaunch(this->context_->thread_pool_, PoolingGradImpl, this, 1); | ||||
| if (error_code != RET_OK) { | if (error_code != RET_OK) { | ||||
| MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; | MS_LOG(ERROR) << "pooling error error_code[" << error_code << "]"; | ||||
| @@ -30,7 +30,7 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | ||||
| int schema_version = VersionVerify(&verify); | int schema_version = VersionVerify(&verify); | ||||
| if (schema_version == -1) { | if (schema_version == -1) { | ||||
| MS_LOG(ERROR) << "The buffer is invalid and fail to create graph."; | |||||
| MS_LOG(ERROR) << "The model buffer is invalid, cannot get schema version"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| TrainModel *model = new (std::nothrow) TrainModel(); | TrainModel *model = new (std::nothrow) TrainModel(); | ||||
| @@ -378,7 +378,7 @@ session::TrainSession *session::TrainSession::CreateSession(const std::string &f | |||||
| ifs.seekg(0, std::ios::end); | ifs.seekg(0, std::ios::end); | ||||
| auto size = ifs.tellg(); | auto size = ifs.tellg(); | ||||
| if (size == 0) { | |||||
| if (size <= 0) { | |||||
| MS_LOG(ERROR) << "Could not read file " << filename; | MS_LOG(ERROR) << "Could not read file " << filename; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -391,8 +391,12 @@ session::TrainSession *session::TrainSession::CreateSession(const std::string &f | |||||
| ifs.seekg(0, std::ios::beg); | ifs.seekg(0, std::ios::beg); | ||||
| ifs.read(buf.get(), size); | ifs.read(buf.get(), size); | ||||
| if (!ifs) { | |||||
| MS_LOG(ERROR) << "only read " << ifs.gcount() << "bytes in " << filename; | |||||
| ifs.close(); | |||||
| return nullptr; | |||||
| } | |||||
| ifs.close(); | ifs.close(); | ||||
| return session::TrainSession::CreateSession(buf.get(), size, context, train_mode); | return session::TrainSession::CreateSession(buf.get(), size, context, train_mode); | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| mini_alexnet | mini_alexnet | ||||
| # mobilenetv1 | |||||
| #mobilenetv1 | |||||
| mobilenetv2 | mobilenetv2 | ||||
| mobilenetv3 | mobilenetv3 | ||||
| lenet | lenet | ||||