From: @yonibaehr_admin Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiangpull/15084/MERGE
| @@ -15,70 +15,130 @@ | |||
| */ | |||
| #include "nnacl/fp32_grad/resize_grad.h" | |||
| #include <math.h> | |||
| #include "nnacl/infer/common_infer.h" | |||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, | |||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, int format, | |||
| ResizeGradParameter *param) { | |||
| bool align_corners = param->align_corners_; | |||
| size_t in_hw_size = param->in_width_ * param->in_height_; | |||
| size_t out_hw_size = param->out_width_ * param->out_height_; | |||
| for (int32_t b = 0; b < batch_size; ++b) { | |||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||
| size_t in_y = i / param->in_width_; | |||
| size_t in_x = i % param->in_width_; | |||
| if (format == Format_NHWC) { | |||
| for (int32_t b = 0; b < batch_size; ++b) { | |||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||
| size_t in_y = i / param->in_width_; | |||
| size_t in_x = i % param->in_width_; | |||
| for (int32_t c = 0; c < channel; ++c) { | |||
| size_t out_y = MSMIN( | |||
| (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), | |||
| param->out_height_ - 1); | |||
| size_t out_x = MSMIN( | |||
| (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), | |||
| param->out_width_ - 1); | |||
| size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; | |||
| size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; | |||
| out_addr[out_offset] += in_addr[in_offset]; | |||
| } | |||
| } | |||
| out_addr += out_hw_size * channel; | |||
| in_addr += in_hw_size * channel; | |||
| } | |||
| } else if (format == Format_NCHW) { | |||
| for (int32_t b = 0; b < batch_size; ++b) { | |||
| for (int32_t c = 0; c < channel; ++c) { | |||
| size_t out_y = MSMIN( | |||
| (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), | |||
| param->out_height_ - 1); | |||
| size_t out_x = MSMIN( | |||
| (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), | |||
| param->out_width_ - 1); | |||
| size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; | |||
| size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; | |||
| out_addr[out_offset] += in_addr[in_offset]; | |||
| for (size_t h = 0; h < param->in_height_; ++h) { | |||
| size_t out_y = | |||
| MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_), | |||
| param->out_height_ - 1); | |||
| for (size_t w = 0; w < param->in_width_; ++w) { | |||
| size_t out_x = | |||
| MSMIN((align_corners) ? (size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), | |||
| param->out_width_ - 1); | |||
| out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; | |||
| } | |||
| } | |||
| out_addr += out_hw_size; | |||
| in_addr += in_hw_size; | |||
| } | |||
| } | |||
| out_addr += out_hw_size * channel; | |||
| in_addr += in_hw_size * channel; | |||
| } | |||
| } | |||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param) { | |||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, int format, | |||
| ResizeGradParameter *param) { | |||
| size_t in_hw_size = param->in_width_ * param->in_height_; | |||
| size_t out_hw_size = param->out_width_ * param->out_height_; | |||
| for (int32_t b = 0; b < batch_size; ++b) { | |||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||
| size_t h = i / param->in_width_; | |||
| size_t w = i % param->in_width_; | |||
| for (int32_t c = 0; c < channel; ++c) { | |||
| const float in_y = (float)h * param->height_scale_; | |||
| size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); | |||
| size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); | |||
| const float y_lerp = in_y - floorf(in_y); | |||
| const float inverse_y_lerp = 1.0 - y_lerp; | |||
| if (format == Format_NHWC) { | |||
| for (int32_t b = 0; b < batch_size; ++b) { | |||
| for (size_t i = 0; i < in_hw_size; ++i) { | |||
| size_t h = i / param->in_width_; | |||
| size_t w = i % param->in_width_; | |||
| for (int32_t c = 0; c < channel; ++c) { | |||
| float in_y = (float)h * param->height_scale_; | |||
| size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); | |||
| size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); | |||
| float y_lerp = in_y - floorf(in_y); | |||
| float inverse_y_lerp = 1.0 - y_lerp; | |||
| const float in_x = (float)w * param->width_scale_; | |||
| size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); | |||
| size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); | |||
| const float x_lerp = in_x - floorf(in_x); | |||
| const float inverse_x_lerp = 1.0 - x_lerp; | |||
| float in_x = (float)w * param->width_scale_; | |||
| size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); | |||
| size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); | |||
| float x_lerp = in_x - floorf(in_x); | |||
| float inverse_x_lerp = 1.0 - x_lerp; | |||
| size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; | |||
| size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||
| size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||
| size_t out_offset_bottom_y_left_x = | |||
| bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||
| size_t out_offset_bottom_y_right_x = | |||
| bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||
| size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; | |||
| size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||
| size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||
| size_t out_offset_bottom_y_left_x = | |||
| bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; | |||
| size_t out_offset_bottom_y_right_x = | |||
| bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; | |||
| out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); | |||
| out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); | |||
| out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); | |||
| out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); | |||
| } | |||
| } | |||
| out_addr += out_hw_size * channel; | |||
| in_addr += in_hw_size * channel; | |||
| } | |||
| } else if (format == Format_NCHW) { | |||
| size_t in_height = param->in_height_; | |||
| size_t in_width = param->in_width_; | |||
| size_t out_height = param->out_height_; | |||
| size_t out_width = param->out_width_; | |||
| out_hw_size = out_height * out_width; | |||
| in_hw_size = in_height * in_width; | |||
| out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); | |||
| out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); | |||
| out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); | |||
| out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); | |||
| for (size_t b = 0; b < batch_size; ++b) { | |||
| for (size_t c = 0; c < channel; ++c) { | |||
| for (size_t h = 0; h < in_height; ++h) { | |||
| const float in_y = (float)(h)*param->height_scale_; | |||
| const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); | |||
| const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); | |||
| const float y_lerp = in_y - floorf(in_y); | |||
| const float inverse_y_lerp = 1.0 - y_lerp; | |||
| for (size_t w = 0; w < in_width; ++w) { | |||
| const float in_x = (float)(w)*param->width_scale_; | |||
| const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); | |||
| const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); | |||
| const float x_lerp = in_x - floorf(in_x); | |||
| const float inverse_x_lerp = 1.0 - x_lerp; | |||
| out_addr[top_y_index * out_width + left_x_index] += | |||
| in_addr[h * in_width + w] * (float)(inverse_y_lerp * inverse_x_lerp); | |||
| out_addr[top_y_index * out_width + right_x_index] += | |||
| in_addr[h * in_width + w] * (float)(inverse_y_lerp * x_lerp); | |||
| out_addr[bottom_y_index * out_width + left_x_index] += | |||
| in_addr[h * in_width + w] * (float)(y_lerp * inverse_x_lerp); | |||
| out_addr[bottom_y_index * out_width + right_x_index] += | |||
| in_addr[h * in_width + w] * (float)(y_lerp * x_lerp); | |||
| } | |||
| } | |||
| out_addr += out_hw_size; | |||
| in_addr += in_hw_size; | |||
| } | |||
| } | |||
| out_addr += out_hw_size * channel; | |||
| in_addr += in_hw_size * channel; | |||
| } | |||
| } | |||
| @@ -35,9 +35,10 @@ typedef struct ResizeGradParameter { | |||
| float width_scale_; | |||
| } ResizeGradParameter; | |||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, | |||
| void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, int format, | |||
| ResizeGradParameter *param); | |||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param); | |||
| void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, int format, | |||
| ResizeGradParameter *param); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "nnacl/infer/resize_grad_infer.h" | |||
| #include "nnacl/infer/infer_register.h" | |||
| int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||
| OpParameter *parameter) { | |||
| #ifdef Debug | |||
| int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); | |||
| if (check_ret != NNACL_OK) { | |||
| return check_ret; | |||
| } | |||
| #endif | |||
| const TensorC *input = inputs[0]; | |||
| if (input->shape_size_ != 4) { | |||
| return NNACL_ERR; | |||
| } | |||
| TensorC *output = outputs[0]; | |||
| SetDataTypeFormat(output, input); | |||
| if (!parameter->infer_flag_) { | |||
| return NNACL_INFER_INVALID; | |||
| } | |||
| const TensorC *input_1 = inputs[1]; | |||
| if (input_1->shape_size_ == 4) { | |||
| ShapeSet(output->shape_, &output->shape_size_, input_1->shape_, input_1->shape_size_); | |||
| } else if (input_1->shape_size_ == 1 && input_1->shape_[0] == 2 && input_1->data_type_ == kNumberTypeInt32) { | |||
| int output_shape[MAX_SHAPE_SIZE]; | |||
| size_t output_shape_size = 0; | |||
| int32_t *data = (int32_t *)(input_1->data_); | |||
| ShapePush(output_shape, &output_shape_size, GetBatch(input)); | |||
| ShapePush(output_shape, &output_shape_size, data[0]); | |||
| ShapePush(output_shape, &output_shape_size, data[1]); | |||
| ShapePush(output_shape, &output_shape_size, GetChannel(input)); | |||
| SetShapeArray(output, output_shape, output_shape_size); | |||
| } else { | |||
| return NNACL_ERR; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| REG_INFER(ResizeGrad, PrimType_ResizeGrad, ResizeGradInferShape) | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ | |||
| #define MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ | |||
| #include "nnacl/infer/common_infer.h" | |||
| #include "nnacl/fp32_grad/resize_grad.h" | |||
| #ifdef __cplusplus | |||
| extern "C" { | |||
| #endif | |||
| int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, | |||
| OpParameter *parameter); | |||
| #ifdef __cplusplus | |||
| } | |||
| #endif | |||
| #endif // MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ | |||
| @@ -22,11 +22,9 @@ import mindspore.common.dtype as mstype | |||
| from mindspore import context, Tensor, nn | |||
| from mindspore.train.serialization import export | |||
| sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/') | |||
| sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet/') | |||
| #pylint: disable=wrong-import-position | |||
| from official.cv.densenet121.src.network.densenet import DenseNet121 | |||
| from src.network.densenet import DenseNet121 | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) | |||
| @@ -40,6 +40,7 @@ using mindspore::dataset::transforms::TypeCast; | |||
| using mindspore::dataset::vision::Normalize; | |||
| using mindspore::dataset::vision::Resize; | |||
| using mindspore::lite::AccuracyMetrics; | |||
| using mindspore::lite::Model; | |||
| using mindspore::session::TrainLoopCallBack; | |||
| using mindspore::session::TrainLoopCallBackData; | |||
| @@ -100,7 +101,10 @@ void NetRunner::InitAndFigureInputs() { | |||
| context.thread_num_ = 2; | |||
| model_ = mindspore::lite::Model::Import(ms_file_.c_str()); | |||
| MS_ASSERT(nullptr != model_); | |||
| if (model_ == nullptr) { | |||
| std::cout << "import model failed" << std::endl; | |||
| return; | |||
| } | |||
| session_ = mindspore::session::TrainSession::CreateSession(model_, &context, true); | |||
| MS_ASSERT(nullptr != session_); | |||
| @@ -184,7 +188,7 @@ int NetRunner::Main() { | |||
| if (epochs_ > 0) { | |||
| auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained.ms"; | |||
| mindspore::lite::Model::Export(model_, trained_fn.c_str()); | |||
| Model::Export(model_, trained_fn.c_str()); | |||
| } | |||
| return 0; | |||
| } | |||
| @@ -121,25 +121,11 @@ int LiteKernel::PreProcess() { | |||
| } | |||
| int LiteKernel::PostProcess() { | |||
| #ifdef SUPPORT_TRAIN | |||
| for (auto input_kernel : this->in_kernels()) { | |||
| MS_ASSERT(input_kernel != nullptr); | |||
| if (input_kernel->is_model_output()) { | |||
| continue; | |||
| } | |||
| auto ret = input_kernel->DecOutTensorRefCount(); | |||
| if (0 != ret) { | |||
| MS_LOG(WARNING) << "DecOutTensorRefCount for kernel" << this->name() << " failed"; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| #else | |||
| for (auto *output : this->out_tensors()) { | |||
| MS_ASSERT(output != nullptr); | |||
| output->ResetRefCount(); | |||
| } | |||
| return FreeInWorkTensor(); | |||
| #endif | |||
| } | |||
| int LiteKernel::Run(const KernelCallBack &before, const KernelCallBack &after) { | |||
| @@ -901,6 +901,9 @@ RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusio | |||
| RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); | |||
| RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); | |||
| RegistryMSOps g_resizeGradPrimitiveCreatorRegistry("ResizeGrad", ResizeGradPrimitiveCreator); | |||
| RegistryMSOps g_resizeBilinearGradPrimitiveCreatorRegistry("ResizeBilinearGrad", ResizeGradPrimitiveCreator); | |||
| RegistryMSOps g_resizeNearestNeighborGradPrimitiveCreatorRegistry("ResizeNearestNeighborGrad", | |||
| ResizeGradPrimitiveCreator); | |||
| RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); | |||
| RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); | |||
| RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); | |||
| @@ -36,27 +36,45 @@ int SelectCPUKernel::Run() { | |||
| auto bool_tensor = in_tensors_.front(); | |||
| MS_ASSERT(bool_tensor != nullptr); | |||
| MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool); | |||
| MS_ASSERT(bool_tensor->Size() == 1); | |||
| MS_ASSERT(bool_tensor->Size() == 1); | |||
| auto condition = static_cast<bool *>(bool_tensor->data_c()); | |||
| if (condition == nullptr) { | |||
| MS_LOG(ERROR) << "data of bool tensor is nullptr"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (*condition) { | |||
| auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), this->in_tensors_.begin() + 1, | |||
| this->in_tensors_.begin() + 1 + this->out_tensors_.size()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "carry data error : " << ret; | |||
| return ret; | |||
| if (bool_tensor->Size() == 1) { | |||
| auto condition = static_cast<bool *>(bool_tensor->data_c()); | |||
| if (condition == nullptr) { | |||
| MS_LOG(ERROR) << "data of bool tensor is nullptr"; | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| if (*condition) { | |||
| auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), this->in_tensors_.begin() + 1, | |||
| this->in_tensors_.begin() + 1 + this->out_tensors_.size()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "carry data error : " << ret; | |||
| return ret; | |||
| } | |||
| } else { | |||
| auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), | |||
| this->in_tensors_.begin() + 1 + this->out_tensors_.size(), | |||
| this->in_tensors_.begin() + 1 + 2 * this->out_tensors_.size()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "carry data error : " << ret; | |||
| return ret; | |||
| } | |||
| } | |||
| } else { | |||
| auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), | |||
| this->in_tensors_.begin() + 1 + this->out_tensors_.size(), | |||
| this->in_tensors_.begin() + 1 + 2 * this->out_tensors_.size()); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "carry data error : " << ret; | |||
| return ret; | |||
| MS_ASSERT(bool_tensor->shape().size() == in_tensors_.at(1)->shape().size()); | |||
| for (size_t i = 0; i < in_tensors_.at(1)->shape().size(); i++) { | |||
| if (bool_tensor->shape()[i] != in_tensors_.at(1)->shape()[i]) { | |||
| MS_LOG(ERROR) << "Tensor shapes differ in dim: " << i << " in_tensors_.at(0): " << bool_tensor->shape()[i] | |||
| << " in_tensors_.at(1): " << in_tensors_.at(1)->shape()[i]; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| MS_ASSERT(in_tensors_.at(1)->Size() == out_tensors_.at(0)->Size()); | |||
| auto size = in_tensors_.at(1)->ElementsNum(); | |||
| auto condition = static_cast<bool *>(bool_tensor->data_c()); | |||
| auto input1 = static_cast<float *>(in_tensors_.at(1)->data_c()); | |||
| auto input2 = static_cast<float *>(in_tensors_.at(2)->data_c()); | |||
| auto output = static_cast<float *>(out_tensors_.at(0)->data_c()); | |||
| for (int i = 0; i < size; i++) { | |||
| output[i] = condition[i] ? input1[i] : input2[i]; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| @@ -71,9 +71,9 @@ int ResizeGradCPUKernel::Execute(int task_id) { | |||
| auto channel = in_tensors_.at(0)->Channel(); | |||
| if (param->method == static_cast<int>(schema::ResizeMethod_NEAREST)) { | |||
| ResizeNearestNeighborGrad(in_addr, out_addr, batch_size, channel, param); | |||
| ResizeNearestNeighborGrad(in_addr, out_addr, batch_size, channel, in_tensors_.at(0)->format(), param); | |||
| } else { | |||
| ResizeBiLinearGrad(in_addr, out_addr, batch_size, channel, param); | |||
| ResizeBiLinearGrad(in_addr, out_addr, batch_size, channel, in_tensors_.at(0)->format(), param); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -153,9 +153,7 @@ int TrainSession::RunGraph(const KernelCallBack &before, const KernelCallBack &a | |||
| MS_LOG(ERROR) << "CheckInputs failed"; | |||
| return ret; | |||
| } | |||
| for (auto out_tensor : outputs_) { // increase RefCount of output tensors, such that Run will not free them | |||
| out_tensor->set_ref_count(out_tensor->ref_count() + 1); | |||
| } | |||
| for (auto *kernel : run_kernel) { | |||
| MS_ASSERT(nullptr != kernel); | |||
| ret = kernel->PreProcess(); | |||
| @@ -340,6 +340,9 @@ int NetTrain::RunExportedNet() { | |||
| std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; | |||
| return RET_ERROR; | |||
| } | |||
| if (flags_->loss_name_ != "") { | |||
| session_->SetLossName(flags_->loss_name_); | |||
| } | |||
| ms_inputs_ = session_->GetInputs(); | |||
| auto end_prepare_time = GetTimeUs(); | |||
| MS_LOG(INFO) << "Exported model PrepareTime = " << (end_prepare_time - start_prepare_time) / 1000 << " ms"; | |||
| @@ -409,6 +412,9 @@ int NetTrain::RunNetTrain() { | |||
| return RET_ERROR; | |||
| } | |||
| if (flags_->loss_name_ != "") { | |||
| session_->SetLossName(flags_->loss_name_); | |||
| } | |||
| session_->Train(); | |||
| ms_inputs_ = session_->GetInputs(); | |||
| @@ -536,10 +542,23 @@ int NetTrain::InitCallbackParameter() { | |||
| op_times_by_name_[call_param.node_name].first++; | |||
| op_times_by_name_[call_param.node_name].second += cost; | |||
| if (layer_checksum_) { | |||
| float *output = reinterpret_cast<float *>(after_outputs.at(0)->MutableData()); | |||
| float sum = 0; | |||
| for (int i = 0; i < after_outputs.at(0)->ElementsNum(); i++) sum += output[i]; | |||
| std::cout << call_param.node_type << " shape= " << after_outputs.at(0)->shape() << " sum=" << sum << "\n"; | |||
| auto out_tensor = after_outputs.at(0); | |||
| void *output = out_tensor->MutableData(); | |||
| int tensor_size = out_tensor->ElementsNum(); | |||
| TypeId type = out_tensor->data_type(); | |||
| std::cout << call_param.node_type << " shape=" << after_outputs.at(0)->shape() << " sum="; | |||
| switch (type) { | |||
| case kNumberTypeFloat32: | |||
| std::cout << TensorSum<float>(output, tensor_size); | |||
| break; | |||
| case kNumberTypeInt32: | |||
| std::cout << TensorSum<int>(output, tensor_size); | |||
| break; | |||
| default: | |||
| std::cout << "unsupported type:" << type; | |||
| break; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| return true; | |||
| }; | |||
| @@ -50,6 +50,15 @@ struct MS_API CheckTensor { | |||
| std::vector<float> data; | |||
| }; | |||
| template <typename T> | |||
| T TensorSum(void *data, int size) { | |||
| T *typed_data = reinterpret_cast<T *>(data); | |||
| T sum = static_cast<T>(0); | |||
| for (int i = 0; i < size; i++) { | |||
| sum += typed_data[i]; | |||
| } | |||
| return sum; | |||
| } | |||
| class MS_API NetTrainFlags : public virtual FlagParser { | |||
| public: | |||
| NetTrainFlags() { | |||
| @@ -67,6 +76,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||
| AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); | |||
| AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); | |||
| AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); | |||
| AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); | |||
| } | |||
| ~NetTrainFlags() override = default; | |||
| @@ -98,6 +108,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { | |||
| std::string resize_dims_in_ = ""; | |||
| bool layer_checksum_ = false; | |||
| std::vector<std::vector<int64_t>> resize_dims_; | |||
| std::string loss_name_ = ""; | |||
| }; | |||
| class MS_API NetTrain { | |||
| @@ -72,6 +72,7 @@ | |||
| #include "ops/stack.h" | |||
| #include "ops/tanh.h" | |||
| #include "ops/sparse_softmax_cross_entropy_with_logits.h" | |||
| #include "ops/grad/resize_grad.h" | |||
| using mindspore::ops::kNameAdd; | |||
| using mindspore::ops::kNameAdder; | |||
| @@ -140,6 +141,8 @@ constexpr auto kNameSlice = "Slice"; | |||
| constexpr auto kNameAvgPoolGradGpu = "AvgPoolGradGpu"; | |||
| constexpr auto kNameAvgPoolGradCpu = "AvgPoolGradCpu"; | |||
| constexpr auto kNameTanhGrad = "TanhGrad"; | |||
| constexpr auto kNameResizeBilinearGrad = "ResizeBilinearGrad"; | |||
| constexpr auto kNameResizeNearestNeighborGrad = "ResizeNearestNeighborGrad"; | |||
| std::map<std::string, mindspore::ActivationType> activation_map = {{ops::kNameElu, mindspore::ELU}, | |||
| {ops::kNameGeLU, mindspore::GELU}, | |||
| @@ -489,6 +492,32 @@ int MoveAttrSlice(const CNodePtr &cnode) { | |||
| value_node->set_value(dst_prim); | |||
| return lite::RET_OK; | |||
| } | |||
| int MoveAttrMapResizeGrad(const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto value_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| MS_ASSERT(value_node != nullptr); | |||
| auto src_prim = GetValueNode<PrimitivePtr>(value_node); | |||
| if (src_prim == nullptr) { | |||
| MS_LOG(ERROR) << "value node is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto dst_prim = std::make_shared<ops::ResizeGrad>(); | |||
| MS_ASSERT(dst_prim != nullptr); | |||
| if (src_prim->name() == kNameResizeBilinearGrad) { | |||
| dst_prim->set_method(ResizeMethod::LINEAR); | |||
| } else if (src_prim->name() == kNameResizeNearestNeighborGrad) { | |||
| dst_prim->set_method(ResizeMethod::NEAREST); | |||
| } else { | |||
| MS_LOG(ERROR) << "Resize grad method " << src_prim->name() << "is not supported"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto align_corners = GetValue<bool>(src_prim->GetAttr(ops::kAlignCorners)); | |||
| dst_prim->set_align_corners(align_corners); | |||
| value_node->set_value(dst_prim); | |||
| return lite::RET_OK; | |||
| } | |||
| } // namespace | |||
| bool PrimitiveAdjustPass::Run(const FuncGraphPtr &func_graph) { | |||
| @@ -591,5 +620,7 @@ REGIST_PRIMITIVE_ADJUST(kNameTile, MoveAttrMapCommon<ops::TileFusion>) | |||
| REGIST_PRIMITIVE_ADJUST(kNameTopK, MoveAttrMapCommon<ops::TopKFusion>) | |||
| REGIST_PRIMITIVE_ADJUST(kNameSparseSoftmaxCrossEntropyWithLogits, | |||
| MoveAttrMapCommon<ops::SparseSoftmaxCrossEntropyWithLogits>) | |||
| REGIST_PRIMITIVE_ADJUST(kNameResizeBilinearGrad, MoveAttrMapResizeGrad) | |||
| REGIST_PRIMITIVE_ADJUST(kNameResizeNearestNeighborGrad, MoveAttrMapResizeGrad) | |||
| } // namespace opt | |||
| } // namespace mindspore | |||