diff --git a/pkg/task/manager.go b/pkg/task/manager.go index 789ca1a..8ecfc40 100644 --- a/pkg/task/manager.go +++ b/pkg/task/manager.go @@ -1,6 +1,7 @@ package task import ( + "fmt" "sync" ) @@ -54,6 +55,12 @@ func (m *Manager[TCtx]) Start(body TaskBody[TCtx], cmp func(self, other TaskBody return task } +func (m *Manager[TCtx]) StartCmp(body ComparableTaskBody[TCtx]) *Task[TCtx] { + return m.Start(body, func(self, other TaskBody[TCtx]) bool { + return body.Compare(other) + }) +} + func (m *Manager[TCtx]) Find(predicate func(body TaskBody[TCtx]) bool) *Task[TCtx] { m.lock.Lock() defer m.lock.Unlock() @@ -69,7 +76,7 @@ func (m *Manager[TCtx]) Find(predicate func(body TaskBody[TCtx]) bool) *Task[TCt func (m *Manager[TCtx]) executeTask(task *Task[TCtx]) { go func() { - task.body.Execute(m.ctx, func(completing func()) { + task.body.Execute(m.ctx, func(err error, completing func()) { // 删除任务 m.lock.Lock() for i, t := range m.tasks { @@ -82,13 +89,40 @@ func (m *Manager[TCtx]) executeTask(task *Task[TCtx]) { completing() m.lock.Unlock() - // 触发waiter回调 task.waiterLock.Lock() task.isCompleted = true + task.err = err + task.waiterLock.Unlock() + + // 触发回调 for _, w := range task.waiters { close(w) } - task.waiterLock.Unlock() + + for _, c := range task.onCompleted { + c(task) + } }) + + // 如果Task没有调用complete函数就退出了,那么就认为是出错结束 + notCompletedYet := false + task.waiterLock.Lock() + if !task.isCompleted { + task.isCompleted = true + task.err = fmt.Errorf("task exit without calling complete function") + notCompletedYet = true + } + task.waiterLock.Unlock() + + if notCompletedYet { + // 触发回调 + for _, w := range task.waiters { + close(w) + } + + for _, c := range task.onCompleted { + c(task) + } + } }() } diff --git a/pkg/task/task.go b/pkg/task/task.go index c6de810..6fb9413 100644 --- a/pkg/task/task.go +++ b/pkg/task/task.go @@ -2,15 +2,36 @@ package task import "sync" +type CompleteFn = func(err error, completing func()) + type TaskBody[TCtx any] interface { - Execute(ctx TCtx, complete func(completing func())) + Execute(ctx TCtx, complete CompleteFn) +} + +type ComparableTaskBody[TCtx any] interface { + TaskBody[TCtx] + Compare(other TaskBody[TCtx]) bool } type Task[TCtx any] struct { body TaskBody[TCtx] isCompleted bool waiters []chan any + onCompleted []func(task *Task[TCtx]) waiterLock sync.Mutex + err error +} + +func (t *Task[TCtx]) Body() TaskBody[TCtx] { + return t.body +} + +func (t *Task[TCtx]) IsCompleted() bool { + return t.isCompleted +} + +func (t *Task[TCtx]) Error() error { + return t.err } func (t *Task[TCtx]) Wait() { @@ -26,3 +47,15 @@ func (t *Task[TCtx]) Wait() { <-waiter } + +func (t *Task[TCtx]) OnCompleted(callback func(task *Task[TCtx])) { + t.waiterLock.Lock() + if t.isCompleted { + t.waiterLock.Unlock() + callback(t) + return + } + + t.onCompleted = append(t.onCompleted, callback) + t.waiterLock.Unlock() +}