/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_PREDICT_NODE_UTIL_H #define MINDSPORE_PREDICT_NODE_UTIL_H #include #include #include #include #include "schema/inner/model_generated.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "securec/include/securec.h" namespace mindspore { namespace lite { using STATUS = int; STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; } inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) { return schema::EnumNamePrimitiveType(GetCNodeTType(cNodeT)); } inline schema::PrimitiveType GetOpType(const schema::CNode &opDef) { return opDef.primitive()->value_type(); } inline std::string GetOpTypeName(const schema::CNode &opDef) { return schema::EnumNamePrimitiveType(GetOpType(opDef)); } std::unordered_map GetNc2NhAxisMap(); std::vector GetInsertOpList(); std::vector GetNhwcOpList(); std::vector GetNhwcDualInputOpList(); std::vector GetNhwcAllInputOpList(); std::vector Getfp32FullOpList(); std::vector GetUint8NhwcOpList(); std::vector GetInt8OpList(); class NodeUtils { public: static STATUS ConvertDims(schema::Format src_format, const std::vector &src_dims, schema::Format dst_format, std::vector *dst_dims); static void SliceData(std::vector &input, int64_t chunk_size, std::vector &output, int64_t begin, int64_t out_dim, int64_t stride); static STATUS SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector &input_dims, std::vector &begin, std::vector &output_dims, schema::TensorT *output, std::vector &stride); }; enum kTransFilterType { kKCHW2HWCK, // 0 kKCHW2KHWC, kCKHW2KHWC, kCKHW2HWCK, kKCHW2HWKC, kCKHW2HWKC, kHWCK2KCHW, kHWCK2CKHW, kHWKC2KCHW, kHWKC2CKHW, kNHWC2KCHW, // 10 kNHWC2CKHW, kNHWC2HWCK, kKHWC2HWCK, kCHWK2HWCK, kKHWC2CHWK, kCHWK2KHWC, kKHWC2KCHW, kCKHW2KCHW, kCHWK2KCHW, kKCHW2CKHW // 20 }; STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, int32_t *filterH, int32_t *filterW); STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); template static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW) { MS_ASSERT(tensor != nullptr); int count = filterH * filterW * filterC * filterK; if (count <= 0) { MS_LOG(ERROR) << "Dim size invalid"; return RET_ERROR; } std::unique_ptr buf(new (std::nothrow) T[count]); if (buf == nullptr) { MS_LOG(ERROR) << "new buf failed"; return RET_ERROR; } void *originWeightDate = tensor->data.data(); T *weightData = static_cast(originWeightDate); if (weightData == nullptr) { MS_LOG(ERROR) << "weightData is nullptr"; return RET_ERROR; } T *p1Buff = nullptr; T *p2Buff = nullptr; switch (type) { case kCHWK2HWCK: case kCHWK2KHWC: { for (int c = 0; c < filterC; ++c) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int k = 0; k < filterK; ++k) { p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); if (type == kCHWK2HWCK) { p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kCHWK2KHWC) { p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } *p2Buff = *p1Buff; } } } } } break; case kKHWC2HWCK: { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); *p2Buff = *p1Buff; } } } } } break; case kKCHW2HWCK: case kKCHW2CKHW: case kKCHW2KHWC: case kKCHW2HWKC: { for (int k = 0; k < filterK; ++k) { for (int c = 0; c < filterC; ++c) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); if (type == kKCHW2HWCK) { p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kKCHW2KHWC) { p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } else if (type == kKCHW2CKHW) { p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); } *p2Buff = *p1Buff; } } } } } break; case kCKHW2HWCK: case kCKHW2KHWC: case kCKHW2HWKC: { for (int c = 0; c < filterC; ++c) { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); if (type == kCKHW2HWCK) { p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kCKHW2KHWC) { p2Buff = buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); } else { p2Buff = buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); } *p2Buff = *p1Buff; } } } } } break; case kHWCK2KCHW: case kHWCK2CKHW: { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { for (int k = 0; k < filterK; ++k) { p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); if (type == kHWCK2KCHW) { p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } } break; case kHWKC2KCHW: case kHWKC2CKHW: { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { for (int k = 0; k < filterK; ++k) { p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); if (type == kHWKC2KCHW) { p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } } break; case kNHWC2HWCK: case kNHWC2KCHW: case kNHWC2CKHW: { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); if (type == kNHWC2HWCK) { p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); } else if (type == kNHWC2CKHW) { p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); } else { p2Buff = buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); } *p2Buff = *p1Buff; } } } } } break; case kKHWC2CHWK: { for (int k = 0; k < filterK; ++k) { for (int h = 0; h < filterH; ++h) { for (int w = 0; w < filterW; ++w) { for (int c = 0; c < filterC; ++c) { p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); *p2Buff = *p1Buff; } } } } } break; default: { MS_LOG(ERROR) << "Unsupported transFilterType: " << type; return RET_ERROR; } } auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed: " << ret; return RET_ERROR; } return RET_OK; } template static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type) { MS_ASSERT(tensor != nullptr); std::vector oriDims = tensor->dims; if (oriDims.size() != (size_t)DIM_DEFAULT_SIZE) { MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); return RET_ERROR; } int32_t filterH; int32_t filterW; int32_t filterC; int32_t filterK; auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW); if (status != RET_OK) { MS_LOG(ERROR) << "GetFilterDim failed: " << status; return status; } status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW); if (status != RET_OK) { MS_LOG(ERROR) << "SetFilterDim failed: " << status; return status; } status = TransFilterData(tensor, type, filterK, filterC, filterH, filterW); if (status != RET_OK) { MS_LOG(ERROR) << "TransFilterData failed: " << status; return status; } return RET_OK; } STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat); } // namespace lite } // namespace mindspore #endif // MINDSPORE_PREDICT_NODE_UTIL_H