|
- package machinery
-
- import (
- "context"
- "errors"
- "fmt"
- "sync"
-
- "github.com/RichardKnop/machinery/v1/backends/result"
- "github.com/RichardKnop/machinery/v1/brokers/eager"
- "github.com/RichardKnop/machinery/v1/config"
- "github.com/RichardKnop/machinery/v1/tasks"
- "github.com/RichardKnop/machinery/v1/tracing"
- "github.com/google/uuid"
-
- backendsiface "github.com/RichardKnop/machinery/v1/backends/iface"
- brokersiface "github.com/RichardKnop/machinery/v1/brokers/iface"
- opentracing "github.com/opentracing/opentracing-go"
- )
-
- // Server is the main Machinery object and stores all configuration
- // All the tasks workers process are registered against the server
- type Server struct {
- config *config.Config
- registeredTasks map[string]interface{}
- broker brokersiface.Broker
- backend backendsiface.Backend
- prePublishHandler func(*tasks.Signature)
- }
-
- // NewServerWithBrokerBackend ...
- func NewServerWithBrokerBackend(cnf *config.Config, brokerServer brokersiface.Broker, backendServer backendsiface.Backend) *Server {
- return &Server{
- config: cnf,
- registeredTasks: make(map[string]interface{}),
- broker: brokerServer,
- backend: backendServer,
- }
- }
-
- // NewServer creates Server instance
- func NewServer(cnf *config.Config) (*Server, error) {
- broker, err := BrokerFactory(cnf)
- if err != nil {
- return nil, err
- }
-
- // Backend is optional so we ignore the error
- backend, _ := BackendFactory(cnf)
-
- srv := NewServerWithBrokerBackend(cnf, broker, backend)
-
- // init for eager-mode
- eager, ok := broker.(eager.Mode)
- if ok {
- // we don't have to call worker.Launch in eager mode
- eager.AssignWorker(srv.NewWorker("eager", 0))
- }
-
- return srv, nil
- }
-
- // NewWorker creates Worker instance
- func (server *Server) NewWorker(consumerTag string, concurrency int) *Worker {
- return &Worker{
- server: server,
- ConsumerTag: consumerTag,
- Concurrency: concurrency,
- Queue: "",
- }
- }
-
- // NewCustomQueueWorker creates Worker instance with Custom Queue
- func (server *Server) NewCustomQueueWorker(consumerTag string, concurrency int, queue string) *Worker {
- return &Worker{
- server: server,
- ConsumerTag: consumerTag,
- Concurrency: concurrency,
- Queue: queue,
- }
- }
-
- // GetBroker returns broker
- func (server *Server) GetBroker() brokersiface.Broker {
- return server.broker
- }
-
- // SetBroker sets broker
- func (server *Server) SetBroker(broker brokersiface.Broker) {
- server.broker = broker
- }
-
- // GetBackend returns backend
- func (server *Server) GetBackend() backendsiface.Backend {
- return server.backend
- }
-
- // SetBackend sets backend
- func (server *Server) SetBackend(backend backendsiface.Backend) {
- server.backend = backend
- }
-
- // GetConfig returns connection object
- func (server *Server) GetConfig() *config.Config {
- return server.config
- }
-
- // SetConfig sets config
- func (server *Server) SetConfig(cnf *config.Config) {
- server.config = cnf
- }
-
- // SetPreTaskHandler Sets pre publish handler
- func (server *Server) SetPreTaskHandler(handler func(*tasks.Signature)) {
- server.prePublishHandler = handler
- }
-
- // RegisterTasks registers all tasks at once
- func (server *Server) RegisterTasks(namedTaskFuncs map[string]interface{}) error {
- for _, task := range namedTaskFuncs {
- if err := tasks.ValidateTask(task); err != nil {
- return err
- }
- }
- server.registeredTasks = namedTaskFuncs
- server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
- return nil
- }
-
- // RegisterTask registers a single task
- func (server *Server) RegisterTask(name string, taskFunc interface{}) error {
- if err := tasks.ValidateTask(taskFunc); err != nil {
- return err
- }
- server.registeredTasks[name] = taskFunc
- server.broker.SetRegisteredTaskNames(server.GetRegisteredTaskNames())
- return nil
- }
-
- // IsTaskRegistered returns true if the task name is registered with this broker
- func (server *Server) IsTaskRegistered(name string) bool {
- _, ok := server.registeredTasks[name]
- return ok
- }
-
- // GetRegisteredTask returns registered task by name
- func (server *Server) GetRegisteredTask(name string) (interface{}, error) {
- taskFunc, ok := server.registeredTasks[name]
- if !ok {
- return nil, fmt.Errorf("Task not registered error: %s", name)
- }
- return taskFunc, nil
- }
-
- // SendTaskWithContext will inject the trace context in the signature headers before publishing it
- func (server *Server) SendTaskWithContext(ctx context.Context, signature *tasks.Signature) (*result.AsyncResult, error) {
- span, _ := opentracing.StartSpanFromContext(ctx, "SendTask", tracing.ProducerOption(), tracing.MachineryTag)
- defer span.Finish()
-
- // tag the span with some info about the signature
- signature.Headers = tracing.HeadersWithSpan(signature.Headers, span)
-
- // Make sure result backend is defined
- if server.backend == nil {
- return nil, errors.New("Result backend required")
- }
-
- // Auto generate a UUID if not set already
- if signature.UUID == "" {
- taskID := uuid.New().String()
- signature.UUID = fmt.Sprintf("task_%v", taskID)
- }
-
- // Set initial task state to PENDING
- if err := server.backend.SetStatePending(signature); err != nil {
- return nil, fmt.Errorf("Set state pending error: %s", err)
- }
-
- if server.prePublishHandler != nil {
- server.prePublishHandler(signature)
- }
-
- if err := server.broker.Publish(ctx, signature); err != nil {
- return nil, fmt.Errorf("Publish message error: %s", err)
- }
-
- return result.NewAsyncResult(signature, server.backend), nil
- }
-
- // SendTask publishes a task to the default queue
- func (server *Server) SendTask(signature *tasks.Signature) (*result.AsyncResult, error) {
- return server.SendTaskWithContext(context.Background(), signature)
- }
-
- // SendChainWithContext will inject the trace context in all the signature headers before publishing it
- func (server *Server) SendChainWithContext(ctx context.Context, chain *tasks.Chain) (*result.ChainAsyncResult, error) {
- span, _ := opentracing.StartSpanFromContext(ctx, "SendChain", tracing.ProducerOption(), tracing.MachineryTag, tracing.WorkflowChainTag)
- defer span.Finish()
-
- tracing.AnnotateSpanWithChainInfo(span, chain)
-
- return server.SendChain(chain)
- }
-
- // SendChain triggers a chain of tasks
- func (server *Server) SendChain(chain *tasks.Chain) (*result.ChainAsyncResult, error) {
- _, err := server.SendTask(chain.Tasks[0])
- if err != nil {
- return nil, err
- }
-
- return result.NewChainAsyncResult(chain.Tasks, server.backend), nil
- }
-
- // SendGroupWithContext will inject the trace context in all the signature headers before publishing it
- func (server *Server) SendGroupWithContext(ctx context.Context, group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) {
- span, _ := opentracing.StartSpanFromContext(ctx, "SendGroup", tracing.ProducerOption(), tracing.MachineryTag, tracing.WorkflowGroupTag)
- defer span.Finish()
-
- tracing.AnnotateSpanWithGroupInfo(span, group, sendConcurrency)
-
- // Make sure result backend is defined
- if server.backend == nil {
- return nil, errors.New("Result backend required")
- }
-
- asyncResults := make([]*result.AsyncResult, len(group.Tasks))
-
- var wg sync.WaitGroup
- wg.Add(len(group.Tasks))
- errorsChan := make(chan error, len(group.Tasks)*2)
-
- // Init group
- server.backend.InitGroup(group.GroupUUID, group.GetUUIDs())
-
- // Init the tasks Pending state first
- for _, signature := range group.Tasks {
- if err := server.backend.SetStatePending(signature); err != nil {
- errorsChan <- err
- continue
- }
- }
-
- pool := make(chan struct{}, sendConcurrency)
- go func() {
- for i := 0; i < sendConcurrency; i++ {
- pool <- struct{}{}
- }
- }()
-
- for i, signature := range group.Tasks {
-
- if sendConcurrency > 0 {
- <-pool
- }
-
- go func(s *tasks.Signature, index int) {
- defer wg.Done()
-
- // Publish task
-
- err := server.broker.Publish(ctx, s)
-
- if sendConcurrency > 0 {
- pool <- struct{}{}
- }
-
- if err != nil {
- errorsChan <- fmt.Errorf("Publish message error: %s", err)
- return
- }
-
- asyncResults[index] = result.NewAsyncResult(s, server.backend)
- }(signature, i)
- }
-
- done := make(chan int)
- go func() {
- wg.Wait()
- done <- 1
- }()
-
- select {
- case err := <-errorsChan:
- return asyncResults, err
- case <-done:
- return asyncResults, nil
- }
- }
-
- // SendGroup triggers a group of parallel tasks
- func (server *Server) SendGroup(group *tasks.Group, sendConcurrency int) ([]*result.AsyncResult, error) {
- return server.SendGroupWithContext(context.Background(), group, sendConcurrency)
- }
-
- // SendChordWithContext will inject the trace context in all the signature headers before publishing it
- func (server *Server) SendChordWithContext(ctx context.Context, chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) {
- span, _ := opentracing.StartSpanFromContext(ctx, "SendChord", tracing.ProducerOption(), tracing.MachineryTag, tracing.WorkflowChordTag)
- defer span.Finish()
-
- tracing.AnnotateSpanWithChordInfo(span, chord, sendConcurrency)
-
- _, err := server.SendGroupWithContext(ctx, chord.Group, sendConcurrency)
- if err != nil {
- return nil, err
- }
-
- return result.NewChordAsyncResult(
- chord.Group.Tasks,
- chord.Callback,
- server.backend,
- ), nil
- }
-
- // SendChord triggers a group of parallel tasks with a callback
- func (server *Server) SendChord(chord *tasks.Chord, sendConcurrency int) (*result.ChordAsyncResult, error) {
- return server.SendChordWithContext(context.Background(), chord, sendConcurrency)
- }
-
- // GetRegisteredTaskNames returns slice of registered task names
- func (server *Server) GetRegisteredTaskNames() []string {
- taskNames := make([]string, len(server.registeredTasks))
- var i = 0
- for name := range server.registeredTasks {
- taskNames[i] = name
- i++
- }
- return taskNames
- }
|