Skip to content
Snippets Groups Projects
Commit ce3b97fd authored by Yordan Kinkov's avatar Yordan Kinkov
Browse files

Move header middleware to its own package

parent 1ab7fc5b
No related branches found
No related tags found
No related merge requests found
......@@ -174,10 +174,6 @@ endpoints for working with arbitrary dynamically uploaded policies.
HTTP Request Headers are passed to the evaluation runtime on each request. One could access any header by name within
the Rego source code using `input.header.name` or `input.header["name"]`.
##### **Important:**
The key `header` is forbidden for request body on the `evaluation` endpoint. Sending a `header` key in
the request body results in `400 Bad Request` response status code and the server will not process the request.
### Policy Extensions Functions
A brief documentation for the available Rego extensions functions
......
......@@ -29,6 +29,7 @@ import (
goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/clients/cache"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/config"
header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regocache"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service"
......@@ -157,7 +158,7 @@ func main() {
}
// Apply middlewares on the servers
policyServer.Evaluate = policy.HeadersMiddleware()(policyServer.Evaluate)
policyServer.Evaluate = header.Middleware()(policyServer.Evaluate)
// Configure the mux.
goapolicysrv.Mount(mux, policyServer)
......
package header
import (
"context"
"net/http"
)
type key string
const headerKey key = "header"
// Middleware is an HTTP server middleware that gets all HTTP headers
// and adds them to a request context value.
func Middleware() func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := ToContext(r.Context(), r)
req := r.WithContext(ctx)
h.ServeHTTP(w, req)
})
}
}
func ToContext(ctx context.Context, r *http.Request) context.Context {
return context.WithValue(ctx, headerKey, r.Header)
}
func FromContext(ctx context.Context) (http.Header, bool) {
header, ok := ctx.Value(headerKey).(http.Header)
return header, ok
}
......@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"github.com/google/uuid"
"github.com/open-policy-agent/opa/rego"
......@@ -13,6 +12,7 @@ import (
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy"
header2 "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/regofunc"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage"
)
......@@ -21,7 +21,7 @@ import (
//go:generate counterfeiter . Storage
//go:generate counterfeiter . RegoCache
const HeadersKey = "header"
const HeaderKey = "header"
type Cache interface {
Set(ctx context.Context, key, namespace, scope string, value []byte, ttl int) error
......@@ -278,41 +278,23 @@ func (s *Service) queryCacheKey(group, policyName, version string) string {
return fmt.Sprintf("%s,%s,%s", group, policyName, version)
}
// HeadersMiddleware is an HTTP server middleware that gets all HTTP headers
// and adds them to a request context value.
func HeadersMiddleware() func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := map[string]string{}
// get all http headers and add them to a newly initialized request context.
for name := range r.Header {
headers[name] = r.Header.Get(name)
}
ctx := context.WithValue(r.Context(), HeadersKey, headers) //nolint:all
req := r.WithContext(ctx)
// call initial handler.
h.ServeHTTP(w, req)
})
}
}
func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (interface{}, error) {
bytes, err := json.Marshal(req.Input)
if err != nil {
return nil, err
func (s *Service) addHeadersToEvaluateInput(ctx context.Context, req *policy.EvaluateRequest) (map[string]interface{}, error) {
i, ok := req.Input.(*interface{})
if !ok {
return nil, errors.New("unexpected request body: unsuccessful casting to interface")
}
var input map[string]interface{}
if err := json.Unmarshal(bytes, &input); err != nil {
return nil, err
i2 := *i
input, ok := i2.(map[string]interface{})
if !ok {
return nil, errors.New("unexpected request body: unsuccessful casting to map")
}
if _, ok := input[HeadersKey]; ok {
return nil, errors.New(errors.BadRequest, fmt.Sprintf("key `%s` is not allowed in the request body", HeadersKey))
header, ok := header2.FromContext(ctx)
if !ok {
return nil, errors.New("error getting headers from context")
}
input[HeadersKey] = ctx.Value(HeadersKey)
input[HeaderKey] = header
return input, err
return input, nil
}
......@@ -2,6 +2,8 @@ package policy_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
......@@ -12,6 +14,7 @@ import (
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/errors"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/golib/ptr"
goapolicy "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/gen/policy"
header "gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/middleware"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/service/policy/policyfakes"
"gitlab.com/gaia-x/data-infrastructure-federation-services/tsa/policy/internal/storage"
......@@ -46,14 +49,31 @@ func TestService_Evaluate(t *testing.T) {
// prepare test request to be used in tests
testReq := func() *goapolicy.EvaluateRequest {
input := map[string]interface{}{"msg": "yes"}
var body interface{} = input
return &goapolicy.EvaluateRequest{
Group: "testgroup",
PolicyName: "example",
Version: "1.0",
Input: map[string]interface{}{"msg": "yes"},
Input: &body,
TTL: ptr.Int(30),
}
}
// prepare http.Request for tests
httpReq := func() *http.Request {
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("Authorization", "my-token")
return req
}
// prepare context containing headers
ctxWithHeaders := func() context.Context {
ctx := header.ToContext(context.Background(), httpReq())
return ctx
}
tests := []struct {
// test input
name string
......@@ -69,6 +89,7 @@ func TestService_Evaluate(t *testing.T) {
}{
{
name: "prepared query is found in queryCache",
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -136,37 +157,9 @@ func TestService_Evaluate(t *testing.T) {
errkind: errors.Forbidden,
errtext: "policy is locked",
},
{
name: "policy is found in storage but request body contains forbidden key",
req: &goapolicy.EvaluateRequest{
Group: "testgroup",
PolicyName: "example",
Version: "1.0",
Input: map[string]interface{}{"msg": "yes", policy.HeadersKey: "baz"},
},
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
return nil, false
},
},
storage: &policyfakes.FakeStorage{
PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) {
return &storage.Policy{
Name: "example",
Group: "testgroup",
Version: "1.0",
Rego: testPolicy,
Locked: false,
LastUpdate: time.Now(),
}, nil
},
},
res: nil,
errkind: errors.BadRequest,
errtext: "error adding headers to evaluate input",
},
{
name: "policy is found in storage and isn't locked",
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -196,6 +189,7 @@ func TestService_Evaluate(t *testing.T) {
},
{
name: "policy is executed successfully, but storing the result in cache fails",
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -224,6 +218,7 @@ func TestService_Evaluate(t *testing.T) {
},
{
name: "policy with blank variable assignment is evaluated successfully",
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -253,13 +248,8 @@ func TestService_Evaluate(t *testing.T) {
},
{
name: "policy is evaluated successfully with TTL sent in the request headers",
req: &goapolicy.EvaluateRequest{
Group: "testgroup",
PolicyName: "example",
Version: "1.0",
Input: map[string]interface{}{"msg": "yes"},
TTL: ptr.Int(30),
},
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
return nil, false
......@@ -288,6 +278,7 @@ func TestService_Evaluate(t *testing.T) {
},
{
name: "policy using static json data is evaluated successfully",
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -318,7 +309,7 @@ func TestService_Evaluate(t *testing.T) {
},
{
name: "policy accessing headers is evaluated successfully",
ctx: context.WithValue(context.Background(), policy.HeadersKey, map[string]interface{}{"Authorization": "my-token"}), //nolint:all
ctx: ctxWithHeaders(),
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
......@@ -343,7 +334,7 @@ func TestService_Evaluate(t *testing.T) {
},
},
res: &goapolicy.EvaluateResult{
Result: map[string]interface{}{"token": "my-token"},
Result: map[string]interface{}{"token": []interface{}{"my-token"}},
},
},
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment