|
|
|
@@ -26,16 +26,19 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
constexpr auto kNameAttention = "Attention"; |
|
|
|
// Attention MultiHeadAttention |
|
|
|
/// \brief MultiHead-Attention op in MindIR. |
|
|
|
class MS_CORE_API Attention : public PrimitiveC { |
|
|
|
public: |
|
|
|
/// \brief Constructor. |
|
|
|
Attention() : PrimitiveC(kNameAttention) { |
|
|
|
InitIOName( |
|
|
|
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"}, |
|
|
|
{"output"}); |
|
|
|
} |
|
|
|
/// \brief Destructor. |
|
|
|
~Attention() override = default; |
|
|
|
MS_DECLARE_PARENT(Attention, PrimitiveC); |
|
|
|
/// \brief Initialize Attention op. |
|
|
|
void Init() {} |
|
|
|
}; |
|
|
|
} // namespace ops |
|
|
|
|