diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..41aeb4e415feca5578da8fa60845196d43e5e88c --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,77 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" +) + +// AuthMiddleware is standard HTTP middleware used for authenticating +// requests carrying a bearer JWT token. +// +// It uses an internal caching mechanism for fetching Json Web Keys from +// a given URL and automatically refreshes the cache on a given time interval. +// +// JWT tokens are expected to carry a Header *kid* claim specifying the +// ID of the public key which should be used for verification. +type AuthMiddleware struct { + jwkSet jwk.Set +} + +func NewMiddleware(jwkURL string, refreshInterval time.Duration, c *http.Client) (*AuthMiddleware, error) { + if jwkURL == "" { + return nil, fmt.Errorf("missing JWK url") + } + + cache := jwk.NewCache(context.Background()) + if err := cache.Register(jwkURL, jwk.WithHTTPClient(c), jwk.WithRefreshInterval(refreshInterval)); err != nil { + return nil, fmt.Errorf("fail to register JWK url with cache: %v", err) + } + _, err := cache.Refresh(context.Background(), jwkURL) + if err != nil { + return nil, fmt.Errorf("fail to refresh JWK cache: %v", err) + } + + return &AuthMiddleware{ + jwkSet: jwk.NewCachedSet(cache, jwkURL), + }, nil +} + +func (a *AuthMiddleware) Handler() func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, err := tokenFromRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + _, err = jwt.Parse([]byte(token), jwt.WithKeySet(a.jwkSet)) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + h.ServeHTTP(w, r) + }) + } +} + +func tokenFromRequest(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + auth := strings.Split(authHeader, " ") + if len(auth) != 2 { + return "", fmt.Errorf("invalid authorization header") + } + + if auth[0] != "Bearer" { + return "", fmt.Errorf("invalid authorization header") + } + + return auth[1], nil +} diff --git a/auth/auth_test.go b/auth/auth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..87431ceb223e9c6e44232ddaacb64d42b576ddae --- /dev/null +++ b/auth/auth_test.go @@ -0,0 +1,183 @@ +package auth_test + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/auth" +) + +var ( + publicKey jwk.RSAPublicKey + privateKey jwk.RSAPrivateKey +) + +// initKeys creates private and public RSA keys and sets them +// in global variables that are used by all tests. +func initKeys() error { + rawprivkey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to create raw private key: %v", err) + } + + privkey, err := jwk.FromRaw(rawprivkey) + if err != nil { + return fmt.Errorf("failed to create private key: %v", err) + } + + pubkey, err := privkey.PublicKey() + if err != nil { + return fmt.Errorf("failed to create public key: %v", err) + } + + privk, ok := privkey.(jwk.RSAPrivateKey) + if !ok { + return fmt.Errorf("cannot cast private key to RSA private key") + } + privateKey = privk + + if err := privateKey.Set(jwk.KeyIDKey, "key1"); err != nil { + return fmt.Errorf("cannot set kid value to private key: %v", err) + } + + pubk, ok := pubkey.(jwk.RSAPublicKey) + if !ok { + return fmt.Errorf("cannot cast public key to RSA public key") + } + publicKey = pubk + + return nil +} + +func TestAuthMiddleware_Handler(t *testing.T) { + err := initKeys() + require.NoError(t, err) + + keyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Create a JWK Set + set := jwk.NewSet() + + var raw interface{} + err := publicKey.Raw(&raw) + assert.NoError(t, err) + + key, err := jwk.FromRaw(raw) + assert.NoError(t, err) + + err = key.Set(jwk.AlgorithmKey, jwa.RS256) + assert.NoError(t, err) + + err = key.Set("kid", "key1") + assert.NoError(t, err) + + err = set.AddKey(key) + assert.NoError(t, err) + + err = json.NewEncoder(w).Encode(set) + assert.NoError(t, err) + })) + + authMiddleware, err := auth.NewMiddleware(keyServer.URL, 1*time.Hour, http.DefaultClient) + require.NoError(t, err) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("everything is fine")) + }) + authHandler := authMiddleware.Handler()(handler) + + t.Run("authenticate with valid token", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + + token, err := createSignedToken() + assert.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + response := httptest.NewRecorder() + authHandler.ServeHTTP(response, req) + + assert.Equal(t, "everything is fine", response.Body.String()) + }) + + t.Run("authenticate with invalid token", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + + token := "deadbeef" + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + response := httptest.NewRecorder() + authHandler.ServeHTTP(response, req) + + assert.Equal(t, "failed to parse jws: invalid compact serialization format: invalid number of segments\n", response.Body.String()) + }) + + t.Run("authenticate with token signed with unknown key (invalid signature)", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + + token := "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6ImtleTEifQ.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTY2NDk2MDg2OCwiZXhwIjoxNjY0OTY0NDY4fQ.FrIA3A228den86qX72o4yP3TiEA9uOf46Yav_vY5daCQ8yeAm3GaBzC_nikt0y9NSCR6K2G2GCm7RcdfP3vQ9CFh2R7FtL4nfjffdauLmXVzp3z_lyBIKYL3RsTGChctfMeYZzk2F6EDmGHeI8xV3KiDC5Gfkvfdp9MfFxVy7DcuEV9MLo_9j4Y-7nfuB1CbdF_1vzSsO0twitePjsB59CNndugJgTUGFjKUJU2_e7vKMR_i9NvFHfJZS2VbtX3vrZ5f_pfOvBSSZJBxG50Uwf6COhtABieVHhhmLBSJq1P1EWRAI26Bk-YtE8k-jfjra9W1RF5DLF7Jh9Lw-utc5A" + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + response := httptest.NewRecorder() + authHandler.ServeHTTP(response, req) + + assert.Equal(t, "could not verify message using any of the signatures or keys\n", response.Body.String()) + }) + + t.Run("request without token", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + + response := httptest.NewRecorder() + authHandler.ServeHTTP(response, req) + + assert.Equal(t, "invalid authorization header\n", response.Body.String()) + }) + + t.Run("invalid authorization header", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) + assert.NoError(t, err) + + token, err := createSignedToken() + assert.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("%s", token)) + + response := httptest.NewRecorder() + authHandler.ServeHTTP(response, req) + + assert.Equal(t, "invalid authorization header\n", response.Body.String()) + }) +} + +func createSignedToken() (string, error) { + token, err := jwt.NewBuilder(). + Claim(`claim1`, `value1`). + Claim(`claim2`, `value2`). + Issuer(`https://example.com`). + Subject("terminator"). + Audience([]string{"skynet"}). + Build() + if err != nil { + return "", fmt.Errorf("failed to build token: %s\n", err) + } + + signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, privateKey)) + if err != nil { + return "", err + } + + return string(signed), nil +} diff --git a/go.mod b/go.mod index 173215aa8e82394bfb7b78e5c20f55eeccf44232..9a534666af475389675023d1e36a8dbe40754d02 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,26 @@ module gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib go 1.19 require ( + github.com/lestrrat-go/jwx/v2 v2.0.6 github.com/stretchr/testify v1.8.0 goa.design/goa/v3 v3.8.5 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect github.com/dimfeld/httptreemux/v5 v5.4.0 // indirect + github.com/goccy/go-json v0.9.11 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/lestrrat-go/blackmagic v1.0.1 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.4 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2e82aa70958e625f889e96efba0c75ec39eca7f4..9467b09eacb32d02fe6ca5a6a73ea46637c6f61a 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,13 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 h1:HbphB4TFFXpv7MNrT52FGrrgVXF1owhMVTHFZIlnvd4= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0/go.mod h1:DZGJHZMqrU4JJqFAWUS2UO1+lbSKsdiOoYi9Zzey7Fc= github.com/dimfeld/httptreemux/v5 v5.4.0 h1:IiHYEjh+A7pYbhWyjmGnj5HZK6gpOOvyBXCJ+BE8/Gs= github.com/dimfeld/httptreemux/v5 v5.4.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw= +github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= +github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -12,15 +17,37 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/lestrrat-go/blackmagic v1.0.1 h1:lS5Zts+5HIC/8og6cGHb0uCcNCa3OUt1ygh3Qz2Fe80= +github.com/lestrrat-go/blackmagic v1.0.1/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8= +github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.0.6 h1:RlyYNLV892Ed7+FTfj1ROoF6x7WxL965PGTHso/60G0= +github.com/lestrrat-go/jwx/v2 v2.0.6/go.mod h1:aVrGuwEr3cp2Prw6TtQvr8sQxe+84gruID5C9TxT64Q= +github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4= +github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= goa.design/goa/v3 v3.8.5 h1:Y0/6ZwmwZftqQBOlBANU9mP4R+h2gIQUyfQMEs98pGU= goa.design/goa/v3 v3.8.5/go.mod h1:+tEl2wNEL54TMAQQ5Mu5il1zl20/7k89XMUv8hVJfa8= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f h1:OeJjE6G4dgCY4PIXvIRQbE8+RX+uXZyGhUy/ksMGJoc= +golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=