Skip to content
Snippets Groups Projects
Commit 854b152c authored by Lyuben Penkovski's avatar Lyuben Penkovski
Browse files

Add package for graceful shutdown of HTTP servers

parent 3c89eed5
Branches
Tags
1 merge request!4Add package for graceful shutdown of HTTP servers
Pipeline #49426 passed
......@@ -6,7 +6,7 @@ before_script:
- cd /go/src/code.vereign.com/${CI_PROJECT_PATH}
unit tests:
image: golang:1.17.7
image: golang:1.17.8
stage: test
tags:
- amd64-docker
......@@ -16,7 +16,7 @@ unit tests:
- go tool cover -func=coverage.out
lint:
image: golangci/golangci-lint:v1.44.2
image: golangci/golangci-lint:v1.45.0
stage: test
tags:
- amd64-docker
......
......@@ -3,4 +3,4 @@
# golib
Go library with utility packages shared across multiple services.
\ No newline at end of file
Go library with utility packages used in TSA backend services.
\ No newline at end of file
package graceful
import (
"context"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// Shutdown gracefully stops the given HTTP server on
// receiving a stop signal or context cancellation signal
// and waits for the active connections to be closed
// for {timeout} period of time.
//
// The {timeout} period is respected in both stop conditions.
func Shutdown(ctx context.Context, srv *http.Server, timeout time.Duration) error {
done := make(chan error, 1)
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
// wait for a signal or context cancellation
select {
case <-c:
case <-ctx.Done():
}
ctx := context.Background()
var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
done <- srv.Shutdown(ctx)
}()
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return err
}
return <-done
}
package graceful_test
import (
"context"
"errors"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"code.vereign.com/gaiax/tsa/golib/graceful"
)
type handler struct {
requestTime time.Duration
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
time.Sleep(h.requestTime)
}
func TestShutdownWithSignal(t *testing.T) {
tests := []struct {
// input
name string
addr string
timeout time.Duration
reqTime time.Duration
// desired outcome
err error
maxShutdownDuration time.Duration
}{
{
name: "without timeout",
addr: ":58430",
timeout: 0,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "with timeout higher than request processing time",
addr: ":58431",
timeout: 500 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "with timeout lower than request processing time",
addr: ":58432",
timeout: 50 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: errors.New("context deadline exceeded"),
maxShutdownDuration: 100 * time.Millisecond,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
srv := &http.Server{
Addr: test.addr,
Handler: &handler{requestTime: test.reqTime},
}
reqerr := make(chan error, 1)
go func() {
if err := graceful.Shutdown(context.Background(), srv, test.timeout); err != nil {
reqerr <- err
}
}()
go func() {
_, err := http.Get("http://localhost" + test.addr)
reqerr <- err
}()
// wait a while so the HTTP request could be sent
time.Sleep(50 * time.Millisecond)
start := time.Now()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
require.NoError(t, proc.Signal(os.Interrupt))
err = <-reqerr
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {
assert.NoError(t, err)
}
assert.True(t, time.Since(start) < test.maxShutdownDuration)
})
}
}
func TestShutdownWithContext(t *testing.T) {
tests := []struct {
// input
name string
contextTimeout time.Duration
addr string
timeout time.Duration
reqTime time.Duration
// desired outcome
err error
maxShutdownDuration time.Duration
}{
{
name: "with context timeout higher than request processing time",
addr: ":58431",
contextTimeout: 500 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "context timeout lower than request processing time",
addr: ":58432",
timeout: 10 * time.Millisecond,
contextTimeout: 100 * time.Millisecond,
reqTime: 300 * time.Millisecond,
err: errors.New("context deadline exceeded"),
maxShutdownDuration: 150 * time.Millisecond,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
srv := &http.Server{
Addr: test.addr,
Handler: &handler{requestTime: test.reqTime},
}
ctx := context.Background()
var cancel context.CancelFunc
if test.contextTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, test.contextTimeout)
defer cancel()
}
reqerr := make(chan error, 1)
go func() {
if err := graceful.Shutdown(ctx, srv, test.timeout); err != nil {
reqerr <- err
}
}()
go func() {
_, err := http.Get("http://localhost" + test.addr)
reqerr <- err
}()
start := time.Now()
// wait a while so the HTTP request could be sent
time.Sleep(50 * time.Millisecond)
err := <-reqerr
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {
assert.NoError(t, err)
}
assert.True(t, time.Since(start) < test.maxShutdownDuration)
})
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment