diff --git a/README.md b/README.md index 8ebd4760568379c507ec3f3eebd0b21fa15b3916..0eee1a63bcd04015519165af84c263d72d261a9d 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,11 @@ can be mapped and used for evaluating all kinds of different policies. Without a package naming rule, there's no way the service can automatically generate HTTP endpoints for working with arbitrary dynamically uploaded policies. +### Access HTTP Headers inside a policy + +HTTP Request Headers are passed to the evaluation runtime on each request. One could access any header by name within +the Rego source code using `input.header.name` or `input.header["name"]`. + ### Policy Extensions Functions A brief documentation for the available Rego extensions functions diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 48d2d9f8afcc6fdeb057a08f51c78bcf6f6c7084..f941473116eae58d78626bab031c280dd8e1846f 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -29,6 +29,7 @@ import ( goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/clients/cache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/config" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/header" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regocache" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service" @@ -156,6 +157,9 @@ func main() { openapiServer = goaopenapisrv.New(openapiEndpoints, mux, dec, enc, nil, errFormatter, nil, nil) } + // Apply middlewares on the servers + policyServer.Evaluate = header.Middleware()(policyServer.Evaluate) + // Configure the mux. goapolicysrv.Mount(mux, policyServer) goahealthsrv.Mount(mux, healthServer) diff --git a/internal/header/header.go b/internal/header/header.go new file mode 100644 index 0000000000000000000000000000000000000000..5c179f2bb045d6c3db4e04ae2ebe80e8719cc836 --- /dev/null +++ b/internal/header/header.go @@ -0,0 +1,32 @@ +package header + +import ( + "context" + "net/http" +) + +type key string + +const headerKey key = "header" + +// Middleware is an HTTP server middleware that gets all HTTP headers +// and adds them to a request context value. +func Middleware() func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := ToContext(r.Context(), r) + req := r.WithContext(ctx) + + h.ServeHTTP(w, req) + }) + } +} + +func ToContext(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, headerKey, r.Header) +} + +func FromContext(ctx context.Context) (http.Header, bool) { + header, ok := ctx.Value(headerKey).(http.Header) + return header, ok +} diff --git a/internal/header/header_test.go b/internal/header/header_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6f4e1ce3413aec9a59c16a4073552c858155694f --- /dev/null +++ b/internal/header/header_test.go @@ -0,0 +1,27 @@ +package header_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/header" +) + +func TestMiddleware(t *testing.T) { + expected := http.Header{"Authorization": []string{"my-token"}} + + req := httptest.NewRequest("POST", "/example", nil) + req.Header = expected + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + value, ok := header.FromContext(r.Context()) + assert.True(t, ok) + assert.Equal(t, expected, value) + }) + + middleware := header.Middleware() + handlerToTest := middleware(nextHandler) + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index e6cd92ae23ba29cb5b0af3d43a8514a2ff9796b0..e76c0b6aa8f993255eb1eace345d6ef640f7a413 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -12,6 +12,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/header" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" ) @@ -20,6 +21,8 @@ import ( //go:generate counterfeiter . Storage //go:generate counterfeiter . RegoCache +const HeaderKey = "header" + type Cache interface { Set(ctx context.Context, key, namespace, scope string, value []byte, ttl int) error Get(ctx context.Context, key, namespace, scope string) ([]byte, error) @@ -80,7 +83,14 @@ func (s *Service) Evaluate(ctx context.Context, req *policy.EvaluateRequest) (*p return nil, errors.New("error evaluating policy", err) } - resultSet, err := query.Eval(ctx, rego.EvalInput(req.Input)) + // add headers to the request input + input, err := s.addHeadersToEvaluateInput(ctx, req.Input) + if err != nil { + logger.Error("error adding headers to evaluate input", zap.Error(err)) + return nil, errors.New("error adding headers to evaluate input", err) + } + + resultSet, err := query.Eval(ctx, rego.EvalInput(input)) if err != nil { logger.Error("error evaluating rego query", zap.Error(err)) return nil, errors.New("error evaluating rego query", err) @@ -267,3 +277,29 @@ func (s *Service) buildRegoArgs(filename, regoPolicy, regoQuery, regoData string func (s *Service) queryCacheKey(group, policyName, version string) string { return fmt.Sprintf("%s,%s,%s", group, policyName, version) } + +func (s *Service) addHeadersToEvaluateInput(ctx context.Context, in interface{}) (map[string]interface{}, error) { + // goa framework decodes the body of the request into a pointer to interface + // for this reason we cast it first to interface pointer and then to map, which is the expected value + i, ok := in.(*interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to interface") + } + + i2 := *i + if i2 == nil { // no request body + i2 = map[string]interface{}{} + } + input, ok := i2.(map[string]interface{}) + if !ok { + return nil, errors.New("unexpected request body: unsuccessful casting to map") + } + + header, ok := header.FromContext(ctx) + if !ok { + return nil, errors.New("error getting headers from context") + } + input[HeaderKey] = header + + return input, nil +} diff --git a/internal/service/policy/service_test.go b/internal/service/policy/service_test.go index a239e5e7187f82f5e7a472a8efa2154918ca2c01..e25f01615ace12067b7cf7cdbeef7acc481e1fda 100644 --- a/internal/service/policy/service_test.go +++ b/internal/service/policy/service_test.go @@ -2,6 +2,8 @@ package policy_test import ( "context" + "net/http" + "net/http/httptest" "testing" "time" @@ -12,6 +14,7 @@ import ( "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/ptr" goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy" + "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/header" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policyfakes" "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage" @@ -33,6 +36,9 @@ func TestService_Evaluate(t *testing.T) { // prepare test policy using static json data during evaluation testPolicyWithStaticData := `package testgroup.example default allow = false allow { data.msg == "hello world" }` + // prepare test policy accessing headers during evaluation + testPolicyAccessingHeaders := `package testgroup.example token := input.header["Authorization"]` + // prepare test query that can be retrieved from rego queryCache testQuery, err := rego.New( rego.Module("example.rego", testPolicy), @@ -43,17 +49,48 @@ func TestService_Evaluate(t *testing.T) { // prepare test request to be used in tests testReq := func() *goapolicy.EvaluateRequest { + input := map[string]interface{}{"msg": "yes"} + var body interface{} = input + + return &goapolicy.EvaluateRequest{ + Group: "testgroup", + PolicyName: "example", + Version: "1.0", + Input: &body, + TTL: ptr.Int(30), + } + } + + // prepare test request with empty body + testEmptyReq := func() *goapolicy.EvaluateRequest { + var body interface{} = nil + return &goapolicy.EvaluateRequest{ Group: "testgroup", PolicyName: "example", Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, + Input: &body, + TTL: ptr.Int(30), } } + // prepare http.Request for tests + httpReq := func() *http.Request { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "my-token") + return req + } + + // prepare context containing headers + ctxWithHeaders := func() context.Context { + ctx := header.ToContext(context.Background(), httpReq()) + return ctx + } + tests := []struct { // test input name string + ctx context.Context req *goapolicy.EvaluateRequest storage policy.Storage regocache policy.RegoCache @@ -65,6 +102,7 @@ func TestService_Evaluate(t *testing.T) { }{ { name: "prepared query is found in queryCache", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -134,6 +172,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is found in storage and isn't locked", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -163,6 +202,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is executed successfully, but storing the result in cache fails", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -191,6 +231,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy with blank variable assignment is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -220,13 +261,8 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy is evaluated successfully with TTL sent in the request headers", - req: &goapolicy.EvaluateRequest{ - Group: "testgroup", - PolicyName: "example", - Version: "1.0", - Input: map[string]interface{}{"msg": "yes"}, - TTL: ptr.Int(30), - }, + ctx: ctxWithHeaders(), + req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { return nil, false @@ -255,6 +291,7 @@ func TestService_Evaluate(t *testing.T) { }, { name: "policy using static json data is evaluated successfully", + ctx: ctxWithHeaders(), req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -283,12 +320,76 @@ func TestService_Evaluate(t *testing.T) { Result: map[string]interface{}{"allow": true}, }, }, + { + name: "policy accessing headers is evaluated successfully", + ctx: ctxWithHeaders(), + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return &storage.Policy{ + Name: "example", + Group: "testgroup", + Version: "1.0", + Rego: testPolicyAccessingHeaders, + Locked: false, + LastUpdate: time.Now(), + }, nil + }, + }, + cache: &policyfakes.FakeCache{ + SetStub: func(ctx context.Context, s string, s2 string, s3 string, bytes []byte, i int) error { + return nil + }, + }, + res: &goapolicy.EvaluateResult{ + Result: map[string]interface{}{"token": []interface{}{"my-token"}}, + }, + }, + { + name: "policy with empty input is evaluated successfully", + ctx: ctxWithHeaders(), + req: testEmptyReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return &storage.Policy{ + Name: "example", + Group: "testgroup", + Version: "1.0", + Rego: testPolicy, + Locked: false, + LastUpdate: time.Now(), + }, nil + }, + }, + cache: &policyfakes.FakeCache{ + SetStub: func(ctx context.Context, s string, s2 string, s3 string, bytes []byte, i int) error { + return nil + }, + }, + res: &goapolicy.EvaluateResult{ + Result: map[string]interface{}{"allow": false}, + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { svc := policy.New(test.storage, test.regocache, test.cache, zap.NewNop()) - res, err := svc.Evaluate(context.Background(), test.req) + ctx := context.Background() + if test.ctx != nil { + ctx = test.ctx + } + res, err := svc.Evaluate(ctx, test.req) if err == nil { assert.Empty(t, test.errtext) assert.NotNil(t, res)