Skip to content
Snippets Groups Projects
Commit 854b152c authored by Lyuben Penkovski's avatar Lyuben Penkovski
Browse files

Add package for graceful shutdown of HTTP servers

parent 3c89eed5
No related branches found
No related tags found
1 merge request!4Add package for graceful shutdown of HTTP servers
Pipeline #49426 passed with stage
in 20 seconds
......@@ -6,7 +6,7 @@ before_script:
- cd /go/src/code.vereign.com/${CI_PROJECT_PATH}
unit tests:
image: golang:1.17.7
image: golang:1.17.8
stage: test
tags:
- amd64-docker
......@@ -16,7 +16,7 @@ unit tests:
- go tool cover -func=coverage.out
lint:
image: golangci/golangci-lint:v1.44.2
image: golangci/golangci-lint:v1.45.0
stage: test
tags:
- amd64-docker
......
......@@ -3,4 +3,4 @@
# golib
Go library with utility packages shared across multiple services.
\ No newline at end of file
Go library with utility packages used in TSA backend services.
\ No newline at end of file
package graceful
import (
"context"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// Shutdown gracefully stops the given HTTP server on
// receiving a stop signal or context cancellation signal
// and waits for the active connections to be closed
// for {timeout} period of time.
//
// The {timeout} period is respected in both stop conditions.
func Shutdown(ctx context.Context, srv *http.Server, timeout time.Duration) error {
done := make(chan error, 1)
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
// wait for a signal or context cancellation
select {
case <-c:
case <-ctx.Done():
}
ctx := context.Background()
var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
done <- srv.Shutdown(ctx)
}()
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return err
}
return <-done
}
package graceful_test
import (
"context"
"errors"
"net/http"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"code.vereign.com/gaiax/tsa/golib/graceful"
)
type handler struct {
requestTime time.Duration
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
time.Sleep(h.requestTime)
}
func TestShutdownWithSignal(t *testing.T) {
tests := []struct {
// input
name string
addr string
timeout time.Duration
reqTime time.Duration
// desired outcome
err error
maxShutdownDuration time.Duration
}{
{
name: "without timeout",
addr: ":58430",
timeout: 0,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "with timeout higher than request processing time",
addr: ":58431",
timeout: 500 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "with timeout lower than request processing time",
addr: ":58432",
timeout: 50 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: errors.New("context deadline exceeded"),
maxShutdownDuration: 100 * time.Millisecond,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
srv := &http.Server{
Addr: test.addr,
Handler: &handler{requestTime: test.reqTime},
}
reqerr := make(chan error, 1)
go func() {
if err := graceful.Shutdown(context.Background(), srv, test.timeout); err != nil {
reqerr <- err
}
}()
go func() {
_, err := http.Get("http://localhost" + test.addr)
reqerr <- err
}()
// wait a while so the HTTP request could be sent
time.Sleep(50 * time.Millisecond)
start := time.Now()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
require.NoError(t, proc.Signal(os.Interrupt))
err = <-reqerr
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {
assert.NoError(t, err)
}
assert.True(t, time.Since(start) < test.maxShutdownDuration)
})
}
}
func TestShutdownWithContext(t *testing.T) {
tests := []struct {
// input
name string
contextTimeout time.Duration
addr string
timeout time.Duration
reqTime time.Duration
// desired outcome
err error
maxShutdownDuration time.Duration
}{
{
name: "with context timeout higher than request processing time",
addr: ":58431",
contextTimeout: 500 * time.Millisecond,
reqTime: 200 * time.Millisecond,
err: nil,
maxShutdownDuration: 250 * time.Millisecond,
},
{
name: "context timeout lower than request processing time",
addr: ":58432",
timeout: 10 * time.Millisecond,
contextTimeout: 100 * time.Millisecond,
reqTime: 300 * time.Millisecond,
err: errors.New("context deadline exceeded"),
maxShutdownDuration: 150 * time.Millisecond,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
srv := &http.Server{
Addr: test.addr,
Handler: &handler{requestTime: test.reqTime},
}
ctx := context.Background()
var cancel context.CancelFunc
if test.contextTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, test.contextTimeout)
defer cancel()
}
reqerr := make(chan error, 1)
go func() {
if err := graceful.Shutdown(ctx, srv, test.timeout); err != nil {
reqerr <- err
}
}()
go func() {
_, err := http.Get("http://localhost" + test.addr)
reqerr <- err
}()
start := time.Now()
// wait a while so the HTTP request could be sent
time.Sleep(50 * time.Millisecond)
err := <-reqerr
if test.err != nil {
assert.EqualError(t, err, test.err.Error())
} else {
assert.NoError(t, err)
}
assert.True(t, time.Since(start) < test.maxShutdownDuration)
})
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment