diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index b8ebcf9e02d58c15298a0a00e545b7583564c9a9..09b15a255b1b89ccd6a495169ba5827d98db481f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -6,7 +6,7 @@ before_script: - cd /go/src/code.vereign.com/${CI_PROJECT_PATH} unit tests: - image: golang:1.17.7 + image: golang:1.17.8 stage: test tags: - amd64-docker @@ -16,7 +16,7 @@ unit tests: - go tool cover -func=coverage.out lint: - image: golangci/golangci-lint:v1.44.2 + image: golangci/golangci-lint:v1.45.0 stage: test tags: - amd64-docker diff --git a/README.md b/README.md index 9a5fd63e199e82e6cc63821915c1375376637647..bac6c0e97ed2dd482ce262a7c505b6241c0ff019 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,4 @@ # golib -Go library with utility packages shared across multiple services. \ No newline at end of file +Go library with utility packages used in TSA backend services. \ No newline at end of file diff --git a/graceful/graceful.go b/graceful/graceful.go new file mode 100644 index 0000000000000000000000000000000000000000..7e7387d98d5546cf30d4b7c51240c9b64f6e5381 --- /dev/null +++ b/graceful/graceful.go @@ -0,0 +1,45 @@ +package graceful + +import ( + "context" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +// Shutdown gracefully stops the given HTTP server on +// receiving a stop signal or context cancellation signal +// and waits for the active connections to be closed +// for {timeout} period of time. +// +// The {timeout} period is respected in both stop conditions. +func Shutdown(ctx context.Context, srv *http.Server, timeout time.Duration) error { + done := make(chan error, 1) + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + // wait for a signal or context cancellation + select { + case <-c: + case <-ctx.Done(): + } + + ctx := context.Background() + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + done <- srv.Shutdown(ctx) + }() + + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + + return <-done +} diff --git a/graceful/graceful_test.go b/graceful/graceful_test.go new file mode 100644 index 0000000000000000000000000000000000000000..df4755f0dc0fb64ad397fc5ab29fe289cbc1b598 --- /dev/null +++ b/graceful/graceful_test.go @@ -0,0 +1,183 @@ +package graceful_test + +import ( + "context" + "errors" + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "code.vereign.com/gaiax/tsa/golib/graceful" +) + +type handler struct { + requestTime time.Duration +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + time.Sleep(h.requestTime) +} + +func TestShutdownWithSignal(t *testing.T) { + tests := []struct { + // input + name string + addr string + timeout time.Duration + reqTime time.Duration + + // desired outcome + err error + maxShutdownDuration time.Duration + }{ + { + name: "without timeout", + addr: ":58430", + timeout: 0, + reqTime: 200 * time.Millisecond, + + err: nil, + maxShutdownDuration: 250 * time.Millisecond, + }, + { + name: "with timeout higher than request processing time", + addr: ":58431", + timeout: 500 * time.Millisecond, + reqTime: 200 * time.Millisecond, + + err: nil, + maxShutdownDuration: 250 * time.Millisecond, + }, + { + name: "with timeout lower than request processing time", + addr: ":58432", + timeout: 50 * time.Millisecond, + reqTime: 200 * time.Millisecond, + + err: errors.New("context deadline exceeded"), + maxShutdownDuration: 100 * time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + srv := &http.Server{ + Addr: test.addr, + Handler: &handler{requestTime: test.reqTime}, + } + + reqerr := make(chan error, 1) + go func() { + if err := graceful.Shutdown(context.Background(), srv, test.timeout); err != nil { + reqerr <- err + } + }() + + go func() { + _, err := http.Get("http://localhost" + test.addr) + reqerr <- err + }() + + // wait a while so the HTTP request could be sent + time.Sleep(50 * time.Millisecond) + + start := time.Now() + + proc, err := os.FindProcess(os.Getpid()) + require.NoError(t, err) + require.NoError(t, proc.Signal(os.Interrupt)) + + err = <-reqerr + + if test.err != nil { + assert.EqualError(t, err, test.err.Error()) + } else { + assert.NoError(t, err) + } + + assert.True(t, time.Since(start) < test.maxShutdownDuration) + }) + } +} + +func TestShutdownWithContext(t *testing.T) { + tests := []struct { + // input + name string + contextTimeout time.Duration + addr string + timeout time.Duration + reqTime time.Duration + + // desired outcome + err error + maxShutdownDuration time.Duration + }{ + { + name: "with context timeout higher than request processing time", + addr: ":58431", + contextTimeout: 500 * time.Millisecond, + reqTime: 200 * time.Millisecond, + + err: nil, + maxShutdownDuration: 250 * time.Millisecond, + }, + { + name: "context timeout lower than request processing time", + addr: ":58432", + timeout: 10 * time.Millisecond, + contextTimeout: 100 * time.Millisecond, + reqTime: 300 * time.Millisecond, + + err: errors.New("context deadline exceeded"), + maxShutdownDuration: 150 * time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + srv := &http.Server{ + Addr: test.addr, + Handler: &handler{requestTime: test.reqTime}, + } + + ctx := context.Background() + var cancel context.CancelFunc + if test.contextTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, test.contextTimeout) + defer cancel() + } + + reqerr := make(chan error, 1) + go func() { + if err := graceful.Shutdown(ctx, srv, test.timeout); err != nil { + reqerr <- err + } + }() + + go func() { + _, err := http.Get("http://localhost" + test.addr) + reqerr <- err + }() + + start := time.Now() + + // wait a while so the HTTP request could be sent + time.Sleep(50 * time.Millisecond) + + err := <-reqerr + + if test.err != nil { + assert.EqualError(t, err, test.err.Error()) + } else { + assert.NoError(t, err) + } + + assert.True(t, time.Since(start) < test.maxShutdownDuration) + }) + } +}