diff --git a/README.md b/README.md index a2fae4887d045376eb966c665fabe2928402a95f..0eee1a63bcd04015519165af84c263d72d261a9d 100644 --- a/README.md +++ b/README.md @@ -174,10 +174,6 @@ endpoints for working with arbitrary dynamically uploaded policies. HTTP Request Headers are passed to the evaluation runtime on each request. One could access any header by name within the Rego source code using `input.header.name` or `input.header["name"]`. -##### **Important:** -The key `header` is forbidden for request body on the `evaluation` endpoint. Sending a `header` key in -the request body results in `400 Bad Request` response status code and the server will not process the request. - ### Policy Extensions Functions A brief documentation for the available Rego extensions functions diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 86ec19422b52a7a8c705d2bef4d4b2027c47206a..71dccc16b819b627a9f4aa572d957ee5ca1cfa69 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -29,6 +29,7 @@ import ( goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/clients/cache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/config" + header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regocache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service" @@ -157,7 +158,7 @@ func main() { } // Apply middlewares on the servers - policyServer.Evaluate = policy.HeadersMiddleware()(policyServer.Evaluate) + policyServer.Evaluate = header.Middleware()(policyServer.Evaluate) // Configure the mux. goapolicysrv.Mount(mux, policyServer) diff --git a/internal/middleware/header.go b/internal/middleware/header.go new file mode 100644 index 0000000000000000000000000000000000000000..5c179f2bb045d6c3db4e04ae2ebe80e8719cc836 --- /dev/null +++ b/internal/middleware/header.go @@ -0,0 +1,32 @@ +package header + +import ( + "context" + "net/http" +) + +type key string + +const headerKey key = "header" + +// Middleware is an HTTP server middleware that gets all HTTP headers +// and adds them to a request context value. +func Middleware() func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ToContext(r.Context(), r) + req := r.WithContext(ctx) + + h.ServeHTTP(w, req) + }) + } +} + +func ToContext(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, headerKey, r.Header) +} + +func FromContext(ctx context.Context) (http.Header, bool) { + header, ok := ctx.Value(headerKey).(http.Header) + return header, ok +} diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index dbc1d9f5b44b44157debef26d04c3d55266fac0a..afca866261880d6fdb52285dd197c6293f82a22e 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net/http" "github.com/google/uuid" "github.com/open-policy-agent/opa/rego" @@ -13,6 +12,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + header2 "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" ) @@ -21,7 +21,7 @@ import ( //go:generate counterfeiter . Storage //go:generate counterfeiter . RegoCache -const HeadersKey = "header" +const HeaderKey = "header" type Cache interface { Set(ctx context.Context, key, namespace, scope string, value []byte, ttl int) error @@ -278,41 +278,23 @@ func (s *Service) queryCacheKey(group, policyName, version string) string { return fmt.Sprintf("%s,%s,%s", group, policyName, version) } -// HeadersMiddleware is an HTTP server middleware that gets all HTTP headers -// and adds them to a request context value. -func HeadersMiddleware() func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - headers := map[string]string{} - - // get all http headers and add them to a newly initialized request context. - for name := range r.Header { - headers[name] = r.Header.Get(name) - } - ctx := context.WithValue(r.Context(), HeadersKey, headers) //nolint:all - req := r.WithContext(ctx) - - // call initial handler. - h.ServeHTTP(w, req) - }) - } -} - -func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (interface{}, error) { - bytes, err := json.Marshal(req.Input) - if err != nil { - return nil, err +func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (map[string]interface{}, error) { + i, ok := req.Input.(*interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to interface") } - var input map[string]interface{} - if err := json.Unmarshal(bytes, &input); err != nil { - return nil, err + i2 := *i + input, ok := i2.(map[string]interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to map") } - if _, ok := input[HeadersKey]; ok { - return nil, errors.New(errors.BadRequest, fmt.Sprintf("key `%s` is not allowed in the request body", HeadersKey)) + header, ok := header2.FromContext(ctx) + if !ok { + return nil, errors.New("error getting headers from context") } - input[HeadersKey] = ctx.Value(HeadersKey) + input[HeaderKey] = header - return input, err + return input, nil } diff --git a/internal/service/policy/service_test.go b/internal/service/policy/service_test.go index 26167735bf5579024ebd548dfec970d2b9354409..3c4fcba1415c4b3a2f212afaae89ad4452d06de6 100644 --- a/internal/service/policy/service_test.go +++ b/internal/service/policy/service_test.go @@ -2,6 +2,8 @@ package policy_test import ( "context" + "net/http" + "net/http/httptest" "testing" "time" @@ -12,6 +14,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/ptr" goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policyfakes" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" @@ -46,14 +49,31 @@ func TestService_Evaluate(t *testing.T) { // prepare test request to be used in tests testReq := func() *goapolicy.EvaluateRequest { + input := map[string]interface{}{"msg": "yes"} + var body interface{} = input + return &goapolicy.EvaluateRequest{ Group: "testgroup", PolicyName: "example", Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, + Input: &body, + TTL: ptr.Int(30), } } + // prepare http.Request for tests + httpReq := func() *http.Request { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "my-token") + return req + } + + // prepare context containing headers + ctxWithHeaders := func() context.Context { + ctx := header.ToContext(context.Background(), httpReq()) + return ctx + } + tests := []struct { // test input name string @@ -69,6 +89,7 @@ func TestService_Evaluate(t *testing.T) { }{ { name: "prepared query is found in queryCache", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -136,37 +157,9 @@ func TestService_Evaluate(t *testing.T) { errkind: errors.Forbidden, errtext: "policy is locked", }, - { - name: "policy is found in storage but request body contains forbidden key", - req: &goapolicy.EvaluateRequest{ - Group: "testgroup", - PolicyName: "example", - Version: "1.0", - Input: map[string]interface{}{"msg": "yes", policy.HeadersKey: "baz"}, - }, - regocache: &policyfakes.FakeRegoCache{ - GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { - return nil, false - }, - }, - storage: &policyfakes.FakeStorage{ - PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { - return &storage.Policy{ - Name: "example", - Group: "testgroup", - Version: "1.0", - Rego: testPolicy, - Locked: false, - LastUpdate: time.Now(), - }, nil - }, - }, - res: nil, - errkind: errors.BadRequest, - errtext: "error adding headers to evaluate input", - }, { name: "policy is found in storage and isn't locked", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -196,6 +189,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is executed successfully, but storing the result in cache fails", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -224,6 +218,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy with blank variable assignment is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -253,13 +248,8 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is evaluated successfully with TTL sent in the request headers", - req: &goapolicy.EvaluateRequest{ - Group: "testgroup", - PolicyName: "example", - Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, - TTL: ptr.Int(30), - }, + ctx: ctxWithHeaders(), + req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { return nil, false @@ -288,6 +278,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy using static json data is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -318,7 +309,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy accessing headers is evaluated successfully", - ctx: context.WithValue(context.Background(), policy.HeadersKey, map[string]interface{}{"Authorization": "my-token"}), //nolint:all + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -343,7 +334,7 @@ func TestService_Evaluate(t *testing.T) { }, }, res: &goapolicy.EvaluateResult{ - Result: map[string]interface{}{"token": "my-token"}, + Result: map[string]interface{}{"token": []interface{}{"my-token"}}, }, }, }