From ce3b97fd057ffb192b67bc3684dbed7247ba5148 Mon Sep 17 00:00:00 2001 From: Yordan Kinkov <yordan.kinkov@vereign.com> Date: Wed, 12 Oct 2022 19:18:03 +0300 Subject: [PATCH] Move header middleware to its own package --- README.md | 4 -- cmd/policy/main.go | 3 +- internal/middleware/header.go | 32 ++++++++++++ internal/service/policy/service.go | 48 ++++++----------- internal/service/policy/service_test.go | 69 +++++++++++-------------- 5 files changed, 79 insertions(+), 77 deletions(-) create mode 100644 internal/middleware/header.go diff --git a/README.md b/README.md index a2fae488..0eee1a63 100644 --- a/README.md +++ b/README.md @@ -174,10 +174,6 @@ endpoints for working with arbitrary dynamically uploaded policies. HTTP Request Headers are passed to the evaluation runtime on each request. One could access any header by name within the Rego source code using `input.header.name` or `input.header["name"]`. -##### **Important:** -The key `header` is forbidden for request body on the `evaluation` endpoint. Sending a `header` key in -the request body results in `400 Bad Request` response status code and the server will not process the request. - ### Policy Extensions Functions A brief documentation for the available Rego extensions functions diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 86ec1942..71dccc16 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -29,6 +29,7 @@ import ( goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/clients/cache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/config" + header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regocache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service" @@ -157,7 +158,7 @@ func main() { } // Apply middlewares on the servers - policyServer.Evaluate = policy.HeadersMiddleware()(policyServer.Evaluate) + policyServer.Evaluate = header.Middleware()(policyServer.Evaluate) // Configure the mux. goapolicysrv.Mount(mux, policyServer) diff --git a/internal/middleware/header.go b/internal/middleware/header.go new file mode 100644 index 00000000..5c179f2b --- /dev/null +++ b/internal/middleware/header.go @@ -0,0 +1,32 @@ +package header + +import ( + "context" + "net/http" +) + +type key string + +const headerKey key = "header" + +// Middleware is an HTTP server middleware that gets all HTTP headers +// and adds them to a request context value. +func Middleware() func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ToContext(r.Context(), r) + req := r.WithContext(ctx) + + h.ServeHTTP(w, req) + }) + } +} + +func ToContext(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, headerKey, r.Header) +} + +func FromContext(ctx context.Context) (http.Header, bool) { + header, ok := ctx.Value(headerKey).(http.Header) + return header, ok +} diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index dbc1d9f5..afca8662 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "github.com/google/uuid" "github.com/open-policy-agent/opa/rego" @@ -13,6 +12,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + header2 "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" ) @@ -21,7 +21,7 @@ import ( //go:generate counterfeiter . Storage //go:generate counterfeiter . RegoCache -const HeadersKey = "header" +const HeaderKey = "header" type Cache interface { Set(ctx context.Context, key, namespace, scope string, value []byte, ttl int) error @@ -278,41 +278,23 @@ func (s *Service) queryCacheKey(group, policyName, version string) string { return fmt.Sprintf("%s,%s,%s", group, policyName, version) } -// HeadersMiddleware is an HTTP server middleware that gets all HTTP headers -// and adds them to a request context value. -func HeadersMiddleware() func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - headers := map[string]string{} - - // get all http headers and add them to a newly initialized request context. - for name := range r.Header { - headers[name] = r.Header.Get(name) - } - ctx := context.WithValue(r.Context(), HeadersKey, headers) //nolint:all - req := r.WithContext(ctx) - - // call initial handler. - h.ServeHTTP(w, req) - }) - } -} - -func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (interface{}, error) { - bytes, err := json.Marshal(req.Input) - if err != nil { - return nil, err +func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (map[string]interface{}, error) { + i, ok := req.Input.(*interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to interface") } - var input map[string]interface{} - if err := json.Unmarshal(bytes, &input); err != nil { - return nil, err + i2 := *i + input, ok := i2.(map[string]interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to map") } - if _, ok := input[HeadersKey]; ok { - return nil, errors.New(errors.BadRequest, fmt.Sprintf("key `%s` is not allowed in the request body", HeadersKey)) + header, ok := header2.FromContext(ctx) + if !ok { + return nil, errors.New("error getting headers from context") } - input[HeadersKey] = ctx.Value(HeadersKey) + input[HeaderKey] = header - return input, err + return input, nil } diff --git a/internal/service/policy/service_test.go b/internal/service/policy/service_test.go index 26167735..3c4fcba1 100644 --- a/internal/service/policy/service_test.go +++ b/internal/service/policy/service_test.go @@ -2,6 +2,8 @@ package policy_test import ( "context" + "net/http" + "net/http/httptest" "testing" "time" @@ -12,6 +14,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/ptr" goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policyfakes" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" @@ -46,14 +49,31 @@ func TestService_Evaluate(t *testing.T) { // prepare test request to be used in tests testReq := func() *goapolicy.EvaluateRequest { + input := map[string]interface{}{"msg": "yes"} + var body interface{} = input + return &goapolicy.EvaluateRequest{ Group: "testgroup", PolicyName: "example", Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, + Input: &body, + TTL: ptr.Int(30), } } + // prepare http.Request for tests + httpReq := func() *http.Request { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "my-token") + return req + } + + // prepare context containing headers + ctxWithHeaders := func() context.Context { + ctx := header.ToContext(context.Background(), httpReq()) + return ctx + } + tests := []struct { // test input name string @@ -69,6 +89,7 @@ func TestService_Evaluate(t *testing.T) { }{ { name: "prepared query is found in queryCache", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -136,37 +157,9 @@ func TestService_Evaluate(t *testing.T) { errkind: errors.Forbidden, errtext: "policy is locked", }, - { - name: "policy is found in storage but request body contains forbidden key", - req: &goapolicy.EvaluateRequest{ - Group: "testgroup", - PolicyName: "example", - Version: "1.0", - Input: map[string]interface{}{"msg": "yes", policy.HeadersKey: "baz"}, - }, - regocache: &policyfakes.FakeRegoCache{ - GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { - return nil, false - }, - }, - storage: &policyfakes.FakeStorage{ - PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { - return &storage.Policy{ - Name: "example", - Group: "testgroup", - Version: "1.0", - Rego: testPolicy, - Locked: false, - LastUpdate: time.Now(), - }, nil - }, - }, - res: nil, - errkind: errors.BadRequest, - errtext: "error adding headers to evaluate input", - }, { name: "policy is found in storage and isn't locked", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -196,6 +189,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is executed successfully, but storing the result in cache fails", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -224,6 +218,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy with blank variable assignment is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -253,13 +248,8 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is evaluated successfully with TTL sent in the request headers", - req: &goapolicy.EvaluateRequest{ - Group: "testgroup", - PolicyName: "example", - Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, - TTL: ptr.Int(30), - }, + ctx: ctxWithHeaders(), + req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { return nil, false @@ -288,6 +278,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy using static json data is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -318,7 +309,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy accessing headers is evaluated successfully", - ctx: context.WithValue(context.Background(), policy.HeadersKey, map[string]interface{}{"Authorization": "my-token"}), //nolint:all + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -343,7 +334,7 @@ func TestService_Evaluate(t *testing.T) { }, }, res: &goapolicy.EvaluateResult{ - Result: map[string]interface{}{"token": "my-token"}, + Result: map[string]interface{}{"token": []interface{}{"my-token"}}, }, }, } -- GitLab