|
- package sqs
-
- import (
- "crypto/md5"
- "encoding/hex"
- "fmt"
- "strings"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/awserr"
- "github.com/aws/aws-sdk-go/aws/request"
- )
-
// errChecksumMissingBody is returned by checksumsMatch when the message body
// to hash is nil.
var errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body")

// errChecksumMissingMD5 is returned by checksumsMatch when the expected MD5
// digest to compare against is nil.
var errChecksumMissingMD5 = fmt.Errorf("cannot verify checksum. missing response MD5")
-
- func setupChecksumValidation(r *request.Request) {
- if aws.BoolValue(r.Config.DisableComputeChecksums) {
- return
- }
-
- switch r.Operation.Name {
- case opSendMessage:
- r.Handlers.Unmarshal.PushBack(verifySendMessage)
- case opSendMessageBatch:
- r.Handlers.Unmarshal.PushBack(verifySendMessageBatch)
- case opReceiveMessage:
- r.Handlers.Unmarshal.PushBack(verifyReceiveMessage)
- }
- }
-
- func verifySendMessage(r *request.Request) {
- if r.DataFilled() && r.ParamsFilled() {
- in := r.Params.(*SendMessageInput)
- out := r.Data.(*SendMessageOutput)
- err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody)
- if err != nil {
- setChecksumError(r, err.Error())
- }
- }
- }
-
- func verifySendMessageBatch(r *request.Request) {
- if r.DataFilled() && r.ParamsFilled() {
- entries := map[string]*SendMessageBatchResultEntry{}
- ids := []string{}
-
- out := r.Data.(*SendMessageBatchOutput)
- for _, entry := range out.Successful {
- entries[*entry.Id] = entry
- }
-
- in := r.Params.(*SendMessageBatchInput)
- for _, entry := range in.Entries {
- if e, ok := entries[*entry.Id]; ok {
- if err := checksumsMatch(entry.MessageBody, e.MD5OfMessageBody); err != nil {
- ids = append(ids, *e.MessageId)
- }
- }
- }
- if len(ids) > 0 {
- setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
- }
- }
- }
-
- func verifyReceiveMessage(r *request.Request) {
- if r.DataFilled() && r.ParamsFilled() {
- ids := []string{}
- out := r.Data.(*ReceiveMessageOutput)
- for i, msg := range out.Messages {
- err := checksumsMatch(msg.Body, msg.MD5OfBody)
- if err != nil {
- if msg.MessageId == nil {
- if r.Config.Logger != nil {
- r.Config.Logger.Log(fmt.Sprintf(
- "WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.",
- r.RequestID, i,
- ))
- }
- continue
- }
-
- ids = append(ids, *msg.MessageId)
- }
- }
- if len(ids) > 0 {
- setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
- }
- }
- }
-
- func checksumsMatch(body, expectedMD5 *string) error {
- if body == nil {
- return errChecksumMissingBody
- } else if expectedMD5 == nil {
- return errChecksumMissingMD5
- }
-
- msum := md5.Sum([]byte(*body))
- sum := hex.EncodeToString(msum[:])
- if sum != *expectedMD5 {
- return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, sum)
- }
-
- return nil
- }
-
- func setChecksumError(r *request.Request, format string, args ...interface{}) {
- r.Retryable = aws.Bool(true)
- r.Error = awserr.New("InvalidChecksum", fmt.Sprintf(format, args...), nil)
- }
|