You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

executor_test.go 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one or more
  3. * contributor license agreements. See the NOTICE file distributed with
  4. * this work for additional information regarding copyright ownership.
  5. * The ASF licenses this file to You under the Apache License, Version 2.0
  6. * (the "License"); you may not use this file except in compliance with
  7. * the License. You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. package executor
  18. import (
  19. "context"
  20. "database/sql"
  21. "encoding/json"
  22. "testing"
  23. "github.com/agiledragon/gomonkey/v2"
  24. "github.com/pkg/errors"
  25. "github.com/stretchr/testify/assert"
  26. "seata.apache.org/seata-go/pkg/datasource/sql/types"
  27. "seata.apache.org/seata-go/pkg/datasource/sql/undo"
  28. serr "seata.apache.org/seata-go/pkg/util/errors"
  29. "seata.apache.org/seata-go/pkg/util/log"
  30. )
  31. type testableBaseExecutor struct {
  32. BaseExecutor
  33. mockCurrentImage *types.RecordImage
  34. }
  35. func (t *testableBaseExecutor) queryCurrentRecords(ctx context.Context, conn *sql.Conn) (*types.RecordImage, error) {
  36. return t.mockCurrentImage, nil
  37. }
  38. func (t *testableBaseExecutor) dataValidationAndGoOn(ctx context.Context, conn *sql.Conn) (bool, error) {
  39. if !undo.UndoConfig.DataValidation {
  40. return true, nil
  41. }
  42. beforeImage := t.sqlUndoLog.BeforeImage
  43. afterImage := t.sqlUndoLog.AfterImage
  44. equals, err := IsRecordsEquals(beforeImage, afterImage)
  45. if err != nil {
  46. return false, err
  47. }
  48. if equals {
  49. log.Infof("Stop rollback because there is no data change between the before data snapshot and the after data snapshot.")
  50. return false, nil
  51. }
  52. currentImage, err := t.queryCurrentRecords(ctx, conn)
  53. if err != nil {
  54. return false, err
  55. }
  56. equals, err = IsRecordsEquals(afterImage, currentImage)
  57. if err != nil {
  58. return false, err
  59. }
  60. if !equals {
  61. equals, err = IsRecordsEquals(beforeImage, currentImage)
  62. if err != nil {
  63. return false, err
  64. }
  65. if equals {
  66. log.Infof("Stop rollback because there is no data change between the before data snapshot and the current data snapshot.")
  67. return false, nil
  68. } else {
  69. oldRowJson, _ := json.Marshal(afterImage.Rows)
  70. newRowJson, _ := json.Marshal(currentImage.Rows)
  71. log.Infof("check dirty data failed, old and new data are not equal, "+
  72. "tableName:[%s], oldRows:[%s],newRows:[%s].", afterImage.TableName, oldRowJson, newRowJson)
  73. return false, serr.New(serr.SQLUndoDirtyError, "has dirty records when undo", nil)
  74. }
  75. }
  76. return true, nil
  77. }
  78. func TestDataValidationAndGoOn(t *testing.T) {
  79. tests := []struct {
  80. name string
  81. beforeImage *types.RecordImage
  82. afterImage *types.RecordImage
  83. currentImage *types.RecordImage
  84. want bool
  85. wantErr bool
  86. }{
  87. {
  88. name: "before == after, skip rollback",
  89. beforeImage: &types.RecordImage{
  90. TableName: "t_user",
  91. Rows: []types.RowImage{
  92. {Columns: []types.ColumnImage{
  93. {ColumnName: "id", Value: 1},
  94. {ColumnName: "name", Value: "a"},
  95. }},
  96. },
  97. },
  98. afterImage: &types.RecordImage{
  99. TableName: "t_user",
  100. Rows: []types.RowImage{
  101. {Columns: []types.ColumnImage{
  102. {ColumnName: "id", Value: 1},
  103. {ColumnName: "name", Value: "a"},
  104. }},
  105. },
  106. },
  107. want: false,
  108. wantErr: false,
  109. },
  110. {
  111. name: "after == current, continue rollback",
  112. beforeImage: &types.RecordImage{
  113. TableName: "t_user",
  114. Rows: []types.RowImage{
  115. {Columns: []types.ColumnImage{
  116. {ColumnName: "id", Value: 1},
  117. {ColumnName: "name", Value: "a"},
  118. }},
  119. },
  120. },
  121. afterImage: &types.RecordImage{
  122. TableName: "t_user",
  123. Rows: []types.RowImage{
  124. {Columns: []types.ColumnImage{
  125. {ColumnName: "id", Value: 1},
  126. {ColumnName: "name", Value: "b"},
  127. }},
  128. },
  129. },
  130. currentImage: &types.RecordImage{
  131. TableName: "t_user",
  132. Rows: []types.RowImage{
  133. {Columns: []types.ColumnImage{
  134. {ColumnName: "id", Value: 1},
  135. {ColumnName: "name", Value: "b"},
  136. }},
  137. },
  138. },
  139. want: true,
  140. wantErr: false,
  141. },
  142. {
  143. name: "current == before, rollback already done",
  144. beforeImage: &types.RecordImage{
  145. TableName: "t_user",
  146. Rows: []types.RowImage{
  147. {Columns: []types.ColumnImage{
  148. {ColumnName: "id", Value: 1},
  149. {ColumnName: "name", Value: "a"},
  150. }},
  151. },
  152. },
  153. afterImage: &types.RecordImage{
  154. TableName: "t_user",
  155. Rows: []types.RowImage{
  156. {Columns: []types.ColumnImage{
  157. {ColumnName: "id", Value: 1},
  158. {ColumnName: "name", Value: "b"},
  159. }},
  160. },
  161. },
  162. currentImage: &types.RecordImage{
  163. TableName: "t_user",
  164. Rows: []types.RowImage{
  165. {Columns: []types.ColumnImage{
  166. {ColumnName: "id", Value: 1},
  167. {ColumnName: "name", Value: "a"},
  168. }},
  169. },
  170. },
  171. want: false,
  172. wantErr: false,
  173. },
  174. {
  175. name: "dirty data",
  176. beforeImage: &types.RecordImage{
  177. TableName: "t_user",
  178. Rows: []types.RowImage{
  179. {Columns: []types.ColumnImage{
  180. {ColumnName: "id", Value: 1},
  181. {ColumnName: "name", Value: "a"},
  182. }},
  183. },
  184. },
  185. afterImage: &types.RecordImage{
  186. TableName: "t_user",
  187. Rows: []types.RowImage{
  188. {Columns: []types.ColumnImage{
  189. {ColumnName: "id", Value: 1},
  190. {ColumnName: "name", Value: "b"},
  191. }},
  192. },
  193. },
  194. currentImage: &types.RecordImage{
  195. TableName: "t_user",
  196. Rows: []types.RowImage{
  197. {Columns: []types.ColumnImage{
  198. {ColumnName: "id", Value: 1},
  199. {ColumnName: "name", Value: "c"},
  200. }},
  201. },
  202. },
  203. want: false,
  204. wantErr: true,
  205. },
  206. }
  207. for _, tt := range tests {
  208. t.Run(tt.name, func(t *testing.T) {
  209. // patch UndoConfig
  210. cfgPatch := gomonkey.ApplyGlobalVar(&undo.UndoConfig, undo.Config{DataValidation: true})
  211. defer cfgPatch.Reset()
  212. // patch IsRecordsEquals
  213. comparePatch := gomonkey.ApplyFunc(IsRecordsEquals, func(a, b *types.RecordImage) (bool, error) {
  214. aj, _ := json.Marshal(a.Rows)
  215. bj, _ := json.Marshal(b.Rows)
  216. return string(aj) == string(bj), nil
  217. })
  218. defer comparePatch.Reset()
  219. executor := &testableBaseExecutor{
  220. BaseExecutor: BaseExecutor{
  221. sqlUndoLog: undo.SQLUndoLog{
  222. BeforeImage: tt.beforeImage,
  223. AfterImage: tt.afterImage,
  224. },
  225. undoImage: tt.afterImage,
  226. },
  227. mockCurrentImage: tt.currentImage,
  228. }
  229. got, err := executor.dataValidationAndGoOn(context.Background(), nil)
  230. assert.Equal(t, tt.want, got)
  231. if tt.wantErr {
  232. var be *serr.SeataError
  233. if errors.As(err, &be) {
  234. assert.Equal(t, serr.SQLUndoDirtyError, be.Code)
  235. } else {
  236. t.Errorf("expected BusinessError, got: %v", err)
  237. }
  238. } else {
  239. assert.NoError(t, err)
  240. }
  241. })
  242. }
  243. }