diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h
index 93f5e564bc..579e11d859 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/trt/trt_utils.h
@@ -38,7 +38,7 @@ class TrtUtils {
     static std::map<nvinfer1::DataType, TypeId> type_list = {{nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
                                                              {nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
                                                              {nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
-                                                             {nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt}};
+                                                             {nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt32}};
 
     auto iter = type_list.find(trt_dtype);
     if (iter == type_list.end()) {
@@ -51,7 +51,8 @@ class TrtUtils {
     static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
                                                              {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
                                                              {TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
-                                                             {TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32}};
+                                                             {TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32},
+                                                             {TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
     auto iter = type_list.find(ms_dtype);
     if (iter == type_list.end()) {
       MS_LOG(EXCEPTION) << "data type not support: " << ms_dtype;
@@ -69,7 +70,7 @@ class TrtUtils {
     return trt_dims;
   }
 
-  static nvinfer1::Dims TrtDimsToMsDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
+  static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
     nvinfer1::Dims trt_dims;
     size_t offset = ignore_batch_dim ? 1 : 0;
     for (size_t i = offset; i < ms_shape.size(); ++i) {
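
The old `TrtDimsToMsDims` name said the opposite of what the function does; the rename to `MsDimsToTrtDims` matches the conversion direction: copy a MindSpore shape into a TensorRT `Dims`, optionally dropping the leading batch dimension. A minimal standalone sketch of that contract (assuming only the public `nbDims`/`d` fields of `nvinfer1::Dims`, as in the TensorRT headers):

    #include <NvInfer.h>
    #include <cstdint>
    #include <vector>

    // {N, C, H, W} -> Dims{4, {N, C, H, W}}, or Dims{3, {C, H, W}} with
    // ignore_batch_dim = true. Illustrative mirror of TrtUtils::MsDimsToTrtDims.
    nvinfer1::Dims ToTrtDims(const std::vector<int64_t> &ms_shape, bool ignore_batch_dim = false) {
      nvinfer1::Dims trt_dims;
      trt_dims.nbDims = 0;
      size_t offset = ignore_batch_dim ? 1 : 0;
      for (size_t i = offset; i < ms_shape.size(); ++i) {
        trt_dims.d[trt_dims.nbDims++] = static_cast<int32_t>(ms_shape[i]);
      }
      return trt_dims;
    }
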
Input batch size: " << input_shape[0] - << "Output batch size: " << output_shape[0]; - return {false, {}}; - } - const nvinfer1::Dims &dims = TrtUtils::MsDimsToTrtDims(output_shape, false); layer->setReshapeDimensions(dims); @@ -62,7 +67,6 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptr &x2_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1); const std::vector &y_shape = AnfAlgo::GetOutputInferShape(node, 0); - // Keep to output auto Broadcast = [&context, &y_shape](nvinfer1::ITensor *tensor, const std::vector &x_shape) { if (x_shape.size() == y_shape.size()) { return tensor; @@ -88,8 +92,8 @@ ConvertResult AddElementLayer(AnfNodePtr node, std::shared_ptrgetOutput(0); }; - auto *x1 = Broadcast(inputs[0].tensor(), x1_shape); - auto *x2 = Broadcast(inputs[1].tensor(), x2_shape); + auto *x1 = Broadcast(ToTensor(&inputs[0], x1_shape, context), x1_shape); + auto *x2 = Broadcast(ToTensor(&inputs[1], x2_shape, context), x2_shape); auto *layer = context->network()->addElementWise(*x1, *x2, op_type); MS_EXCEPTION_IF_NULL(layer); @@ -142,6 +146,80 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptrgetOutput(0))}}; } + +ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr context, + nvinfer1::UnaryOperation op_type) { + std::vector inputs; + bool ret = context->LoadLayerInput(node, &inputs); + if (!ret || inputs.size() != 1) { + MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected."; + return {false, {}}; + } + + auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type); + MS_EXCEPTION_IF_NULL(layer); + + return {true, {LayerInput(layer->getOutput(0))}}; +} + +ConvertResult addReduceLayer(AnfNodePtr node, std::shared_ptr context, + nvinfer1::ReduceOperation op_type) { + std::vector inputs; + bool ret = context->LoadLayerInput(node, &inputs); + if (!ret || inputs.size() != 1) { + MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected."; + return {false, {}}; + } + + // Calculate reduce axes bitmask + const std::vector &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); + const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis"); + uint32_t reduce_axes = 0; + if (value->isa() || value->isa()) { + const auto &axis = AnfAlgo::GetNodeAttr>(node, "axis"); + for (size_t i = 0; i < axis.size(); i++) { + int offset = axis[i] >= 0 ? LongToInt(axis[i]) : LongToInt(axis[i] + input_shape.size()); + reduce_axes |= 1UL << offset; + } + } else { + const auto &axis = AnfAlgo::GetNodeAttr(node, "axis"); + int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size()); + reduce_axes = 1UL << offset; + } + + // Tensor-RT do not support reduce with no dimensions. + // Skip reduce operator if reduce_axes == 0 + if (reduce_axes == 0) { + MS_LOG(WARNING) << "No dimension be be reduced. " << node->DebugString(); + return {true, {LayerInput(inputs[0].tensor())}}; + } + + bool keep_dims = AnfAlgo::GetNodeAttr(node, "keep_dims"); + // Tensor-RT do not support reduce all dimensions with keep_dims == false. + // Reduce with keep_dims = true, add apply reshape latter. 
@@ -142,6 +146,80 @@ ConvertResult AddActivationLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext>
   return {true, {LayerInput(layer->getOutput(0))}};
 }
+
+ConvertResult AddUnaryLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
+                            nvinfer1::UnaryOperation op_type) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  auto *layer = context->network()->addUnary(*inputs[0].tensor(), op_type);
+  MS_EXCEPTION_IF_NULL(layer);
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
+ConvertResult AddReduceLayer(AnfNodePtr node, std::shared_ptr<TrtConverterContext> context,
+                             nvinfer1::ReduceOperation op_type) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  // Calculate the reduce axes bitmask.
+  const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
+  uint32_t reduce_axes = 0;
+  if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
+    const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
+    for (size_t i = 0; i < axis.size(); i++) {
+      int offset = axis[i] >= 0 ? LongToInt(axis[i]) : LongToInt(axis[i] + input_shape.size());
+      reduce_axes |= 1UL << offset;
+    }
+  } else {
+    const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
+    int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
+    reduce_axes = 1UL << offset;
+  }
+
+  // TensorRT does not support reduce with an empty axis set.
+  // Skip the reduce operator if reduce_axes == 0.
+  if (reduce_axes == 0) {
+    MS_LOG(WARNING) << "No dimension to be reduced. " << node->DebugString();
+    return {true, {LayerInput(inputs[0].tensor())}};
+  }
+
+  bool keep_dims = AnfAlgo::GetNodeAttr<bool>(node, "keep_dims");
+  // TensorRT does not support reducing all dimensions with keep_dims == false.
+  // Reduce with keep_dims = true, then apply a reshape afterwards.
+  bool post_reshape = false;
+  if (keep_dims == false && (reduce_axes == (1UL << input_shape.size()) - 1)) {
+    keep_dims = true;
+    post_reshape = true;
+  }
+
+  nvinfer1::IReduceLayer *layer = context->network()->addReduce(*inputs[0].tensor(), op_type, reduce_axes, keep_dims);
+  MS_EXCEPTION_IF_NULL(layer);
+
+  if (post_reshape) {
+    nvinfer1::IShuffleLayer *reshape_layer = context->network()->addShuffle(*layer->getOutput(0));
+    MS_EXCEPTION_IF_NULL(reshape_layer);
+
+    nvinfer1::Dims dim;
+    dim.nbDims = 1;
+    dim.d[0] = 1;
+    reshape_layer->setReshapeDimensions(dim);
+
+    return {true, {LayerInput(reshape_layer->getOutput(0))}};
+  }
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
 }  // namespace
 
 // Register operator converter from AnfNode to trt layer: `OPNAME` should keep the same as primitive definition.
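
The `axes` argument of `addReduce` (and of `setAxes` further down) is a dimension bitmask, not an index: bit i selects dimension i. A standalone worked sketch of the computation in `AddReduceLayer` (hypothetical helper, shown only to pin down the semantics):

    #include <cstdint>
    #include <vector>

    // Negative axes wrap around the rank, as in the converter.
    uint32_t ReduceAxesMask(const std::vector<int64_t> &axes, size_t rank) {
      uint32_t mask = 0;
      for (int64_t a : axes) {
        int offset = a >= 0 ? static_cast<int>(a) : static_cast<int>(a + rank);
        mask |= 1U << offset;
      }
      return mask;
    }

    // ReduceAxesMask({-1}, 3) == 0b100. Reducing every axis of a rank-3 input
    // gives 0b111 == (1U << 3) - 1, the "reduce all" case that triggers the
    // keep_dims workaround above.
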
@@ -195,6 +273,7 @@ MS_TRT_CONVERTER_FUNC_REG(Add) { return AddElementLayer(node, context, nvinfer1:
 MS_TRT_CONVERTER_FUNC_REG(Sub) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kSUB); }
 MS_TRT_CONVERTER_FUNC_REG(Mul) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPROD); }
 MS_TRT_CONVERTER_FUNC_REG(Div) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
+MS_TRT_CONVERTER_FUNC_REG(RealDiv) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kDIV); }
 MS_TRT_CONVERTER_FUNC_REG(Pow) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kPOW); }
 MS_TRT_CONVERTER_FUNC_REG(Maximum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMAX); }
 MS_TRT_CONVERTER_FUNC_REG(Minimum) { return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kMIN); }
@@ -202,6 +281,33 @@ MS_TRT_CONVERTER_FUNC_REG(FloorDiv) {
   return AddElementLayer(node, context, nvinfer1::ElementWiseOperation::kFLOOR_DIV);
 }
 
+// Unary operators
+MS_TRT_CONVERTER_FUNC_REG(Exp) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kEXP); }
+MS_TRT_CONVERTER_FUNC_REG(Log) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kLOG); }
+MS_TRT_CONVERTER_FUNC_REG(Sqrt) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSQRT); }
+MS_TRT_CONVERTER_FUNC_REG(Reciprocal) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kRECIP); }
+MS_TRT_CONVERTER_FUNC_REG(Abs) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kABS); }
+MS_TRT_CONVERTER_FUNC_REG(Neg) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kNEG); }
+MS_TRT_CONVERTER_FUNC_REG(Sin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSIN); }
+MS_TRT_CONVERTER_FUNC_REG(Cos) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOS); }
+MS_TRT_CONVERTER_FUNC_REG(Tan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kTAN); }
+MS_TRT_CONVERTER_FUNC_REG(Sinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kSINH); }
+MS_TRT_CONVERTER_FUNC_REG(Cosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCOSH); }
+MS_TRT_CONVERTER_FUNC_REG(Asin) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASIN); }
+MS_TRT_CONVERTER_FUNC_REG(Acos) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOS); }
+MS_TRT_CONVERTER_FUNC_REG(Atan) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kATAN); }
+MS_TRT_CONVERTER_FUNC_REG(Asinh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kASINH); }
+MS_TRT_CONVERTER_FUNC_REG(Acosh) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kACOSH); }
+MS_TRT_CONVERTER_FUNC_REG(Ceil) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kCEIL); }
+MS_TRT_CONVERTER_FUNC_REG(Floor) { return AddUnaryLayer(node, context, nvinfer1::UnaryOperation::kFLOOR); }
+
+// Reduce operators
+MS_TRT_CONVERTER_FUNC_REG(ReduceSum) { return AddReduceLayer(node, context, nvinfer1::ReduceOperation::kSUM); }
+MS_TRT_CONVERTER_FUNC_REG(ReduceMean) { return AddReduceLayer(node, context, nvinfer1::ReduceOperation::kAVG); }
+MS_TRT_CONVERTER_FUNC_REG(ReduceMax) { return AddReduceLayer(node, context, nvinfer1::ReduceOperation::kMAX); }
+MS_TRT_CONVERTER_FUNC_REG(ReduceMin) { return AddReduceLayer(node, context, nvinfer1::ReduceOperation::kMIN); }
+MS_TRT_CONVERTER_FUNC_REG(ReduceProd) { return AddReduceLayer(node, context, nvinfer1::ReduceOperation::kPROD); }
+
 // Pooling operators.
 MS_TRT_CONVERTER_FUNC_REG(AvgPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kAVERAGE); }
 MS_TRT_CONVERTER_FUNC_REG(MaxPool) { return AddPoolingLayer(node, context, nvinfer1::PoolingType::kMAX); }
@@ -304,6 +410,45 @@ MS_TRT_CONVERTER_FUNC_REG(MatMul) {
   }
 }
 
+MS_TRT_CONVERTER_FUNC_REG(BatchMatMul) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 2) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
+    return {false, {}};
+  }
+
+  const auto &transpose_a = AnfAlgo::GetNodeAttr<bool>(node, "transpose_a");
+  const auto &transpose_b = AnfAlgo::GetNodeAttr<bool>(node, "transpose_b");
+  const auto &trt_transpose1 = transpose_a ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
+  const auto &trt_transpose2 = transpose_b ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE;
+
+  std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
+
+  auto SwapLastDims = [](std::vector<size_t> shape, const bool &transpose) {
+    if (shape.size() < 2) {
+      MS_LOG(EXCEPTION) << "Operation not supported: input rank should be >= 2";
+    }
+
+    if (!transpose) {
+      return shape;
+    }
+
+    size_t tmp = shape[shape.size() - 2];
+    shape[shape.size() - 2] = shape[shape.size() - 1];
+    shape[shape.size() - 1] = tmp;
+    return shape;
+  };
+
+  nvinfer1::ITensor *tensor1 = ToTensor(&inputs[0], SwapLastDims(shape1, transpose_a), context);
+  nvinfer1::ITensor *tensor2 = ToTensor(&inputs[1], SwapLastDims(shape2, transpose_b), context);
+  auto *layer = context->network()->addMatrixMultiply(*tensor1, trt_transpose1, *tensor2, trt_transpose2);
+  MS_EXCEPTION_IF_NULL(layer);
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
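
For reference on the flags: `nvinfer1::MatrixOperation::kTRANSPOSE` transposes the trailing two dimensions of its operand inside the multiply, so (b, m, k) x (b, n, k) with the second operand transposed yields (b, m, n). A minimal illustrative wiring, not part of the patch (assumes `network`, `a`, `b` already exist):

    // y = A * B^T with A: (b, m, k), B: (b, n, k) -> y: (b, m, n).
    nvinfer1::IMatrixMultiplyLayer *MatMulBT(nvinfer1::INetworkDefinition *network,
                                             nvinfer1::ITensor *a, nvinfer1::ITensor *b) {
      return network->addMatrixMultiply(*a, nvinfer1::MatrixOperation::kNONE,
                                        *b, nvinfer1::MatrixOperation::kTRANSPOSE);
    }
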
 MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
   std::vector<LayerInput> inputs;
   bool ret = context->LoadLayerInput(node, &inputs);
@@ -321,7 +466,7 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
     return {false, {}};
   }
 
-  // Convert Weight to ITensor which
+  // Convert Weight to ITensor
   nvinfer1::Dims unsqueeze_bias_dims;
   unsqueeze_bias_dims.nbDims = x_shape.size();
   std::fill(unsqueeze_bias_dims.d, unsqueeze_bias_dims.d + unsqueeze_bias_dims.nbDims, 1);
@@ -335,10 +480,9 @@ MS_TRT_CONVERTER_FUNC_REG(BiasAdd) {
   return {true, {LayerInput(layer->getOutput(0))}};
 }
 
+// NoOp
 MS_TRT_CONVERTER_FUNC_REG(Reshape) { return AddReshapeLayer(node, context); }
-
 MS_TRT_CONVERTER_FUNC_REG(ExpandDims) { return AddReshapeLayer(node, context); }
-
 MS_TRT_CONVERTER_FUNC_REG(Squeeze) { return AddReshapeLayer(node, context); }
 
 MS_TRT_CONVERTER_FUNC_REG(BatchNorm) {
@@ -466,5 +610,207 @@ MS_TRT_CONVERTER_FUNC_REG(Conv2DBackpropInput) {
   return {true, {LayerInput(layer->getOutput(0))}};
 }
+
+MS_TRT_CONVERTER_FUNC_REG(Slice) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  const auto &begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "begin");
+  const auto &size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "size");
+
+  nvinfer1::Dims trt_start = TrtUtils::MsDimsToTrtDims(begin, false);
+  nvinfer1::Dims trt_size = TrtUtils::MsDimsToTrtDims(size, false);
+  nvinfer1::Dims trt_stride;
+  trt_stride.nbDims = 0;
+  for (int32_t i = 0; i < trt_start.nbDims; i++) {
+    trt_stride.d[trt_stride.nbDims++] = 1;
+  }
+
+  auto *layer = context->network()->addSlice(*inputs[0].tensor(), trt_start, trt_size, trt_stride);
+  MS_EXCEPTION_IF_NULL(layer);
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
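
`addSlice` takes per-dimension start, size, and stride; with unit strides (as built above, after initializing `nbDims` to 0) it is a plain crop. An illustrative call keeping rows 8..23 of a (32, 128) input (shapes hypothetical; assumes `network` and `in` exist):

    // out = in[8:24, 0:128]
    nvinfer1::Dims2 start{8, 0};
    nvinfer1::Dims2 size{16, 128};
    nvinfer1::Dims2 stride{1, 1};
    nvinfer1::ISliceLayer *slice = network->addSlice(*in, start, size, stride);
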
+MS_TRT_CONVERTER_FUNC_REG(Transpose) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  const auto &perm = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
+  nvinfer1::Permutation trt_perm;
+  for (size_t i = 0; i < perm.size(); i++) {
+    trt_perm.order[i] = LongToInt(perm[i]);
+  }
+
+  auto *layer = context->network()->addShuffle(*inputs[0].tensor());
+  MS_EXCEPTION_IF_NULL(layer);
+  layer->setFirstTranspose(trt_perm);
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
+MS_TRT_CONVERTER_FUNC_REG(Softmax) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  const ValuePtr &value = AnfAlgo::GetCNodePrimitive(node)->GetAttr("axis");
+  uint32_t reduce_axes = 0;
+  if (value->isa<ValueTuple>() || value->isa<ValueList>()) {
+    const auto &axis = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
+    if (axis.size() != 1) {
+      MS_LOG(ERROR) << "Only one axis can be set. Axis size: " << axis.size();
+      return {false, {}};
+    }
+    int offset = axis[0] >= 0 ? LongToInt(axis[0]) : LongToInt(axis[0] + input_shape.size());
+    reduce_axes = 1U << offset;
+  } else {
+    const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
+    int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
+    reduce_axes = 1U << offset;
+  }
+
+  auto *layer = context->network()->addSoftMax(*inputs[0].tensor());
+  MS_EXCEPTION_IF_NULL(layer);
+  layer->setAxes(reduce_axes);
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
+MS_TRT_CONVERTER_FUNC_REG(LogSoftmax) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 1 expected.";
+    return {false, {}};
+  }
+
+  const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  const auto &axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
+  int offset = axis >= 0 ? LongToInt(axis) : LongToInt(axis + input_shape.size());
+  uint32_t reduce_axes = 1U << offset;
+
+  auto *softmax_layer = context->network()->addSoftMax(*inputs[0].tensor());
+  MS_EXCEPTION_IF_NULL(softmax_layer);
+  softmax_layer->setAxes(reduce_axes);
+
+  auto *log_layer = context->network()->addUnary(*softmax_layer->getOutput(0), nvinfer1::UnaryOperation::kLOG);
+  MS_EXCEPTION_IF_NULL(log_layer);
+
+  return {true, {LayerInput(log_layer->getOutput(0))}};
+}
+
+MS_TRT_CONVERTER_FUNC_REG(Gather) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 2) {
+    MS_LOG(ERROR) << "Input num not match: " << inputs.size() << ", with 2 expected.";
+    return {false, {}};
+  }
+
+  const std::vector<size_t> &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  auto axis = AnfAlgo::GetNodeAttr<int64_t>(node, "axis");
+  axis = axis >= 0 ? axis : axis + input_shape.size();
+
+  nvinfer1::ITensor *input = ToTensor(&inputs[0], input_shape, context);
+  const std::vector<size_t> &indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 1);
+  nvinfer1::ITensor *indices = ToTensor(&inputs[1], indices_shape, context);
+
+  auto *layer = context->network()->addGather(*input, *indices, LongToInt(axis));
+  MS_EXCEPTION_IF_NULL(layer);
+
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
+MS_TRT_CONVERTER_FUNC_REG(Cast) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 1 || !inputs[0].IsTensor()) {
+    MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
+    return {false, {}};
+  }
+
+  const TypeId &dst_type = AnfAlgo::GetOutputInferDataType(node, 0);
+  auto trt_type = TrtUtils::MsDtypeToTrtDtype(dst_type);
+  auto *layer = context->network()->addIdentity(*inputs[0].tensor());
+  MS_EXCEPTION_IF_NULL(layer);
+  layer->setOutputType(0, trt_type);
+  return {true, {LayerInput(layer->getOutput(0))}};
+}
+
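
This TensorRT API has no dedicated cast layer, so `Cast` is expressed as an identity layer whose output type is pinned with `setOutputType`. An illustrative fp32-to-fp16 cast (assumes `network` and `in` exist; the builder must also permit the target precision):

    nvinfer1::IIdentityLayer *cast = network->addIdentity(*in);
    cast->setOutputType(0, nvinfer1::DataType::kHALF);
    nvinfer1::ITensor *out = cast->getOutput(0);
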
+MS_TRT_CONVERTER_FUNC_REG(LayerNorm) {
+  std::vector<LayerInput> inputs;
+  bool ret = context->LoadLayerInput(node, &inputs);
+  if (!ret || inputs.size() != 3 || !inputs[0].IsTensor() || !inputs[1].IsWeight() || !inputs[2].IsWeight()) {
+    MS_LOG(ERROR) << "Get inputs failed. Input num: " << inputs.size();
+    return {false, {}};
+  }
+
+  // Calculate reduce axes
+  const std::vector<size_t> &input_shape = AnfAlgo::GetOutputInferShape(node, 0);
+  auto begin_norm_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_norm_axis");
+  begin_norm_axis = begin_norm_axis >= 0 ? begin_norm_axis : begin_norm_axis + input_shape.size();
+  uint32_t reduce_axes = 0;
+  for (size_t i = LongToSize(begin_norm_axis); i < input_shape.size(); i++) {
+    reduce_axes |= 1UL << i;
+  }
+
+  // Reshape gamma and beta for broadcast
+  auto begin_params_axis = AnfAlgo::GetNodeAttr<int64_t>(node, "begin_params_axis");
+  begin_params_axis = begin_params_axis >= 0 ? begin_params_axis : begin_params_axis + input_shape.size();
+  std::vector<size_t> param_shape = input_shape;
+  for (size_t j = 0; j < LongToSize(begin_params_axis); j++) {
+    param_shape[j] = 1;
+  }
+
+  auto epsilon = AnfAlgo::GetNodeAttr<float>(node, "epsilon");
+  std::shared_ptr<tensor::Tensor> weight = context->CreateTempWeight(kNumberTypeFloat32, {1});
+  auto value = static_cast<float *>(weight->data_c());
+  value[0] = epsilon;
+  nvinfer1::Dims dim;
+  dim.nbDims = SizeToInt(input_shape.size());
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    dim.d[i] = 1;
+  }
+  auto *epsilon_layer = context->network()->addConstant(dim, nvinfer1::Weights{nvinfer1::DataType::kFLOAT, value, 1});
+  MS_EXCEPTION_IF_NULL(epsilon_layer);
+
+  // y = (x - mean) / sqrt(var + epsilon) * gamma + beta
+  auto *mean = context->network()->addReduce(*inputs[0].tensor(), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
+  MS_EXCEPTION_IF_NULL(mean);
+  auto *sub =
+    context->network()->addElementWise(*inputs[0].tensor(), *mean->getOutput(0), nvinfer1::ElementWiseOperation::kSUB);
+  MS_EXCEPTION_IF_NULL(sub);
+  auto *pow =
+    context->network()->addElementWise(*sub->getOutput(0), *sub->getOutput(0), nvinfer1::ElementWiseOperation::kPROD);
+  MS_EXCEPTION_IF_NULL(pow);
+  auto *var = context->network()->addReduce(*pow->getOutput(0), nvinfer1::ReduceOperation::kAVG, reduce_axes, true);
+  MS_EXCEPTION_IF_NULL(var);
+  auto *var_epsilon = context->network()->addElementWise(*var->getOutput(0), *epsilon_layer->getOutput(0),
+                                                         nvinfer1::ElementWiseOperation::kSUM);
+  MS_EXCEPTION_IF_NULL(var_epsilon);
+  auto *std = context->network()->addUnary(*var_epsilon->getOutput(0), nvinfer1::UnaryOperation::kSQRT);
+  MS_EXCEPTION_IF_NULL(std);
+  auto *div =
+    context->network()->addElementWise(*sub->getOutput(0), *std->getOutput(0), nvinfer1::ElementWiseOperation::kDIV);
+  MS_EXCEPTION_IF_NULL(div);
+  auto *mul = context->network()->addElementWise(*div->getOutput(0), *ToTensor(&inputs[1], param_shape, context),
+                                                 nvinfer1::ElementWiseOperation::kPROD);
+  MS_EXCEPTION_IF_NULL(mul);
+  auto *add = context->network()->addElementWise(*mul->getOutput(0), *ToTensor(&inputs[2], param_shape, context),
+                                                 nvinfer1::ElementWiseOperation::kSUM);
+  MS_EXCEPTION_IF_NULL(add);
+
+  return {true, {LayerInput(add->getOutput(0))}};
+}
 }  // namespace opt
 }  // namespace mindspore
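
The LayerNorm chain computes y = (x - mean) / sqrt(var + epsilon) * gamma + beta with two kAVG reduces (the mean, then the mean of the centered squares) and squares via self-multiplication rather than a pow layer. A scalar reference sketch of the same math, handy for checking the layer wiring (the converter itself broadcasts per-element gamma/beta reshaped to param_shape):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // One normalized slice with scalar gamma/beta; mirrors the layer chain above.
    void LayerNormRef(const std::vector<float> &x, float gamma, float beta, float eps,
                      std::vector<float> *y) {
      float mean = 0.f;
      for (float v : x) mean += v;
      mean /= x.size();
      float var = 0.f;
      for (float v : x) var += (v - mean) * (v - mean);  // square as self-product
      var /= x.size();
      const float inv_std = 1.f / std::sqrt(var + eps);
      y->resize(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        (*y)[i] = (x[i] - mean) * inv_std * gamma + beta;
      }
    }
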