Skip to content
Snippets Groups Projects
Commit 5bc64389 authored by Lyuben Penkovski's avatar Lyuben Penkovski
Browse files

Store policy evaluation result in the Cache service

parent 62efad2d
No related branches found
No related tags found
1 merge request!16Store all policy evaluation results in the cache with unique evaluation ID
Pipeline #50902 passed
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
goapolicysrv "code.vereign.com/gaiax/tsa/policy/gen/http/policy/server" goapolicysrv "code.vereign.com/gaiax/tsa/policy/gen/http/policy/server"
"code.vereign.com/gaiax/tsa/policy/gen/openapi" "code.vereign.com/gaiax/tsa/policy/gen/openapi"
goapolicy "code.vereign.com/gaiax/tsa/policy/gen/policy" goapolicy "code.vereign.com/gaiax/tsa/policy/gen/policy"
"code.vereign.com/gaiax/tsa/policy/internal/clients/cache"
"code.vereign.com/gaiax/tsa/policy/internal/config" "code.vereign.com/gaiax/tsa/policy/internal/config"
"code.vereign.com/gaiax/tsa/policy/internal/regocache" "code.vereign.com/gaiax/tsa/policy/internal/regocache"
"code.vereign.com/gaiax/tsa/policy/internal/regofunc" "code.vereign.com/gaiax/tsa/policy/internal/regofunc"
...@@ -65,6 +66,8 @@ func main() { ...@@ -65,6 +66,8 @@ func main() {
} }
defer db.Disconnect(context.Background()) //nolint:errcheck defer db.Disconnect(context.Background()) //nolint:errcheck
httpClient := httpClient()
// create storage // create storage
storage, err := storage.New(db, cfg.Mongo.DB, cfg.Mongo.Collection, logger) storage, err := storage.New(db, cfg.Mongo.DB, cfg.Mongo.Collection, logger)
if err != nil { if err != nil {
...@@ -76,7 +79,6 @@ func main() { ...@@ -76,7 +79,6 @@ func main() {
// register rego extension functions // register rego extension functions
{ {
httpClient := httpClient()
cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient) cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient)
didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient) didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient)
taskFuncs := regofunc.NewTaskFuncs(cfg.Task.Addr, httpClient) taskFuncs := regofunc.NewTaskFuncs(cfg.Task.Addr, httpClient)
...@@ -90,13 +92,16 @@ func main() { ...@@ -90,13 +92,16 @@ func main() {
// subscribe the cache for policy data changes // subscribe the cache for policy data changes
storage.AddPolicyChangeSubscriber(regocache) storage.AddPolicyChangeSubscriber(regocache)
// create cache client
cache := cache.New(cfg.Cache.Addr, cache.WithHTTPClient(httpClient))
// create services // create services
var ( var (
policySvc goapolicy.Service policySvc goapolicy.Service
healthSvc goahealth.Service healthSvc goahealth.Service
) )
{ {
policySvc = policy.New(storage, regocache, logger) policySvc = policy.New(storage, regocache, cache, logger)
healthSvc = health.New() healthSvc = health.New()
} }
......
...@@ -4,6 +4,7 @@ go 1.17 ...@@ -4,6 +4,7 @@ go 1.17
require ( require (
code.vereign.com/gaiax/tsa/golib v0.0.0-20220321093827-5fdf8f34aad9 code.vereign.com/gaiax/tsa/golib v0.0.0-20220321093827-5fdf8f34aad9
github.com/google/uuid v1.3.0
github.com/kelseyhightower/envconfig v1.4.0 github.com/kelseyhightower/envconfig v1.4.0
github.com/open-policy-agent/opa v0.38.1 github.com/open-policy-agent/opa v0.38.1
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
...@@ -22,7 +23,6 @@ require ( ...@@ -22,7 +23,6 @@ require (
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect github.com/gobwas/glob v0.2.3 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gopherjs/gopherjs v0.0.0-20220221023154-0b2280d3ff96 // indirect github.com/gopherjs/gopherjs v0.0.0-20220221023154-0b2280d3ff96 // indirect
github.com/gorilla/websocket v1.5.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect
......
package cache
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"code.vereign.com/gaiax/tsa/golib/errors"
)
// Client for the Cache service.
type Client struct {
addr string
httpClient *http.Client
}
func New(addr string, opts ...Option) *Client {
c := &Client{
addr: addr,
httpClient: http.DefaultClient,
}
for _, opt := range opts {
opt(c)
}
return c
}
func (c *Client) Set(ctx context.Context, key, namespace, scope string, value []byte) error {
req, err := http.NewRequestWithContext(ctx, "POST", c.addr+"/v1/cache", bytes.NewReader(value))
if err != nil {
return err
}
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
msg := fmt.Sprintf("unexpected response: %d %s", resp.StatusCode, resp.Status)
return errors.New(errors.GetKind(resp.StatusCode), msg)
}
return nil
}
func (c *Client) Get(ctx context.Context, key, namespace, scope string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "GET", c.addr+"/v1/cache", nil)
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotFound {
return nil, errors.New(errors.NotFound)
}
msg := fmt.Sprintf("unexpected response: %d %s", resp.StatusCode, resp.Status)
return nil, errors.New(errors.GetKind(resp.StatusCode), msg)
}
return io.ReadAll(resp.Body)
}
package cache
import (
"net/http"
)
type Option func(*Client)
func WithHTTPClient(client *http.Client) Option {
return func(c *Client) {
c.httpClient = client
}
}
// Code generated by counterfeiter. DO NOT EDIT.
package policyfakes
import (
"context"
"sync"
"code.vereign.com/gaiax/tsa/policy/internal/service/policy"
)
type FakeCache struct {
GetStub func(context.Context, string, string, string) ([]byte, error)
getMutex sync.RWMutex
getArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 string
arg4 string
}
getReturns struct {
result1 []byte
result2 error
}
getReturnsOnCall map[int]struct {
result1 []byte
result2 error
}
SetStub func(context.Context, string, string, string, []byte) error
setMutex sync.RWMutex
setArgsForCall []struct {
arg1 context.Context
arg2 string
arg3 string
arg4 string
arg5 []byte
}
setReturns struct {
result1 error
}
setReturnsOnCall map[int]struct {
result1 error
}
invocations map[string][][]interface{}
invocationsMutex sync.RWMutex
}
func (fake *FakeCache) Get(arg1 context.Context, arg2 string, arg3 string, arg4 string) ([]byte, error) {
fake.getMutex.Lock()
ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)]
fake.getArgsForCall = append(fake.getArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 string
arg4 string
}{arg1, arg2, arg3, arg4})
stub := fake.GetStub
fakeReturns := fake.getReturns
fake.recordInvocation("Get", []interface{}{arg1, arg2, arg3, arg4})
fake.getMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3, arg4)
}
if specificReturn {
return ret.result1, ret.result2
}
return fakeReturns.result1, fakeReturns.result2
}
func (fake *FakeCache) GetCallCount() int {
fake.getMutex.RLock()
defer fake.getMutex.RUnlock()
return len(fake.getArgsForCall)
}
func (fake *FakeCache) GetCalls(stub func(context.Context, string, string, string) ([]byte, error)) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = stub
}
func (fake *FakeCache) GetArgsForCall(i int) (context.Context, string, string, string) {
fake.getMutex.RLock()
defer fake.getMutex.RUnlock()
argsForCall := fake.getArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4
}
func (fake *FakeCache) GetReturns(result1 []byte, result2 error) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = nil
fake.getReturns = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeCache) GetReturnsOnCall(i int, result1 []byte, result2 error) {
fake.getMutex.Lock()
defer fake.getMutex.Unlock()
fake.GetStub = nil
if fake.getReturnsOnCall == nil {
fake.getReturnsOnCall = make(map[int]struct {
result1 []byte
result2 error
})
}
fake.getReturnsOnCall[i] = struct {
result1 []byte
result2 error
}{result1, result2}
}
func (fake *FakeCache) Set(arg1 context.Context, arg2 string, arg3 string, arg4 string, arg5 []byte) error {
var arg5Copy []byte
if arg5 != nil {
arg5Copy = make([]byte, len(arg5))
copy(arg5Copy, arg5)
}
fake.setMutex.Lock()
ret, specificReturn := fake.setReturnsOnCall[len(fake.setArgsForCall)]
fake.setArgsForCall = append(fake.setArgsForCall, struct {
arg1 context.Context
arg2 string
arg3 string
arg4 string
arg5 []byte
}{arg1, arg2, arg3, arg4, arg5Copy})
stub := fake.SetStub
fakeReturns := fake.setReturns
fake.recordInvocation("Set", []interface{}{arg1, arg2, arg3, arg4, arg5Copy})
fake.setMutex.Unlock()
if stub != nil {
return stub(arg1, arg2, arg3, arg4, arg5)
}
if specificReturn {
return ret.result1
}
return fakeReturns.result1
}
func (fake *FakeCache) SetCallCount() int {
fake.setMutex.RLock()
defer fake.setMutex.RUnlock()
return len(fake.setArgsForCall)
}
func (fake *FakeCache) SetCalls(stub func(context.Context, string, string, string, []byte) error) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = stub
}
func (fake *FakeCache) SetArgsForCall(i int) (context.Context, string, string, string, []byte) {
fake.setMutex.RLock()
defer fake.setMutex.RUnlock()
argsForCall := fake.setArgsForCall[i]
return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5
}
func (fake *FakeCache) SetReturns(result1 error) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = nil
fake.setReturns = struct {
result1 error
}{result1}
}
func (fake *FakeCache) SetReturnsOnCall(i int, result1 error) {
fake.setMutex.Lock()
defer fake.setMutex.Unlock()
fake.SetStub = nil
if fake.setReturnsOnCall == nil {
fake.setReturnsOnCall = make(map[int]struct {
result1 error
})
}
fake.setReturnsOnCall[i] = struct {
result1 error
}{result1}
}
func (fake *FakeCache) Invocations() map[string][][]interface{} {
fake.invocationsMutex.RLock()
defer fake.invocationsMutex.RUnlock()
fake.getMutex.RLock()
defer fake.getMutex.RUnlock()
fake.setMutex.RLock()
defer fake.setMutex.RUnlock()
copiedInvocations := map[string][][]interface{}{}
for key, value := range fake.invocations {
copiedInvocations[key] = value
}
return copiedInvocations
}
func (fake *FakeCache) recordInvocation(key string, args []interface{}) {
fake.invocationsMutex.Lock()
defer fake.invocationsMutex.Unlock()
if fake.invocations == nil {
fake.invocations = map[string][][]interface{}{}
}
if fake.invocations[key] == nil {
fake.invocations[key] = [][]interface{}{}
}
fake.invocations[key] = append(fake.invocations[key], args)
}
var _ policy.Cache = new(FakeCache)
...@@ -2,8 +2,10 @@ package policy ...@@ -2,8 +2,10 @@ package policy
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/open-policy-agent/opa/rego" "github.com/open-policy-agent/opa/rego"
"go.uber.org/zap" "go.uber.org/zap"
...@@ -13,9 +15,15 @@ import ( ...@@ -13,9 +15,15 @@ import (
"code.vereign.com/gaiax/tsa/policy/internal/storage" "code.vereign.com/gaiax/tsa/policy/internal/storage"
) )
//go:generate counterfeiter . Cache
//go:generate counterfeiter . Storage //go:generate counterfeiter . Storage
//go:generate counterfeiter . RegoCache //go:generate counterfeiter . RegoCache
type Cache interface {
Set(ctx context.Context, key, namespace, scope string, value []byte) error
Get(ctx context.Context, key, namespace, scope string) ([]byte, error)
}
type Storage interface { type Storage interface {
Policy(ctx context.Context, group, name, version string) (*storage.Policy, error) Policy(ctx context.Context, group, name, version string) (*storage.Policy, error)
SetPolicyLock(ctx context.Context, group, name, version string, lock bool) error SetPolicyLock(ctx context.Context, group, name, version string, lock bool) error
...@@ -29,13 +37,15 @@ type RegoCache interface { ...@@ -29,13 +37,15 @@ type RegoCache interface {
type Service struct { type Service struct {
storage Storage storage Storage
queryCache RegoCache queryCache RegoCache
cache Cache
logger *zap.Logger logger *zap.Logger
} }
func New(storage Storage, queryCache RegoCache, logger *zap.Logger) *Service { func New(storage Storage, queryCache RegoCache, cache Cache, logger *zap.Logger) *Service {
return &Service{ return &Service{
storage: storage, storage: storage,
queryCache: queryCache, queryCache: queryCache,
cache: cache,
logger: logger, logger: logger,
} }
} }
...@@ -47,12 +57,14 @@ func New(storage Storage, queryCache RegoCache, logger *zap.Logger) *Service { ...@@ -47,12 +57,14 @@ func New(storage Storage, queryCache RegoCache, logger *zap.Logger) *Service {
// be exactly the same as 'group.policy'. For example: // be exactly the same as 'group.policy'. For example:
// Evaluating the URL: `.../policies/mygroup/example/1.0/evaluation` will // Evaluating the URL: `.../policies/mygroup/example/1.0/evaluation` will
// return results correctly, only if the package declaration inside the policy is: // return results correctly, only if the package declaration inside the policy is:
// `package mygroup.example` // `package mygroup.example`.
func (s *Service) Evaluate(ctx context.Context, req *policy.EvaluateRequest) (interface{}, error) { func (s *Service) Evaluate(ctx context.Context, req *policy.EvaluateRequest) (interface{}, error) {
evaluationID := uuid.NewString()
logger := s.logger.With( logger := s.logger.With(
zap.String("group", req.Group), zap.String("group", req.Group),
zap.String("name", req.PolicyName), zap.String("name", req.PolicyName),
zap.String("version", req.Version), zap.String("version", req.Version),
zap.String("evaluationID", evaluationID),
) )
query, err := s.prepareQuery(ctx, req.Group, req.PolicyName, req.Version) query, err := s.prepareQuery(ctx, req.Group, req.PolicyName, req.Version)
...@@ -77,7 +89,23 @@ func (s *Service) Evaluate(ctx context.Context, req *policy.EvaluateRequest) (in ...@@ -77,7 +89,23 @@ func (s *Service) Evaluate(ctx context.Context, req *policy.EvaluateRequest) (in
return nil, errors.New("policy evaluation result expressions are empty") return nil, errors.New("policy evaluation result expressions are empty")
} }
return resultSet[0].Expressions[0].Value, nil jsonValue, err := json.Marshal(resultSet[0].Expressions[0].Value)
if err != nil {
logger.Error("error encoding result to json", zap.Error(err))
return nil, errors.New("error encoding result to json")
}
if err := s.cache.Set(ctx, evaluationID, "", "", jsonValue); err != nil {
logger.Error("error storing policy result in cache", zap.Error(err))
return nil, errors.New("error storing policy result in cache")
}
result := map[string]interface{}{
"evaluationID": evaluationID,
"result": resultSet[0].Expressions[0].Value,
}
return result, nil
} }
// Lock a policy so that it cannot be evaluated. // Lock a policy so that it cannot be evaluated.
......
...@@ -17,9 +17,7 @@ import ( ...@@ -17,9 +17,7 @@ import (
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
storage := &policyfakes.FakeStorage{} svc := policy.New(nil, nil, nil, zap.NewNop())
regocache := &policyfakes.FakeRegoCache{}
svc := policy.New(storage, regocache, zap.NewNop())
assert.Implements(t, (*goapolicy.Service)(nil), svc) assert.Implements(t, (*goapolicy.Service)(nil), svc)
} }
...@@ -50,7 +48,7 @@ func TestService_Evaluate(t *testing.T) { ...@@ -50,7 +48,7 @@ func TestService_Evaluate(t *testing.T) {
req *goapolicy.EvaluateRequest req *goapolicy.EvaluateRequest
storage policy.Storage storage policy.Storage
regocache policy.RegoCache regocache policy.RegoCache
cache policy.Cache
// expected result // expected result
res interface{} res interface{}
errkind errors.Kind errkind errors.Kind
...@@ -65,6 +63,11 @@ func TestService_Evaluate(t *testing.T) { ...@@ -65,6 +63,11 @@ func TestService_Evaluate(t *testing.T) {
return &q, true return &q, true
}, },
}, },
cache: &policyfakes.FakeCache{
SetStub: func(ctx context.Context, s string, s2 string, s3 string, bytes []byte) error {
return nil
},
},
res: map[string]interface{}{"allow": true}, res: map[string]interface{}{"allow": true},
}, },
{ {
...@@ -138,17 +141,55 @@ func TestService_Evaluate(t *testing.T) { ...@@ -138,17 +141,55 @@ func TestService_Evaluate(t *testing.T) {
}, nil }, nil
}, },
}, },
cache: &policyfakes.FakeCache{
SetStub: func(ctx context.Context, s string, s2 string, s3 string, bytes []byte) error {
return nil
},
},
res: map[string]interface{}{"allow": true}, res: map[string]interface{}{"allow": true},
}, },
{
name: "policy is executed successfully, but storing the result in cache fails",
req: testReq(),
regocache: &policyfakes.FakeRegoCache{
GetStub: func(key string) (*rego.PreparedEvalQuery, bool) {
return nil, false
},
},
storage: &policyfakes.FakeStorage{
PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) {
return &storage.Policy{
Name: "example",
Group: "testgroup",
Version: "1.0",
Rego: testPolicy,
Locked: false,
LastUpdate: time.Now(),
}, nil
},
},
cache: &policyfakes.FakeCache{
SetStub: func(ctx context.Context, s string, s2 string, s3 string, bytes []byte) error {
return errors.New("some error")
},
},
errkind: errors.Unknown,
errtext: "error storing policy result in cache",
},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
svc := policy.New(test.storage, test.regocache, zap.NewNop()) svc := policy.New(test.storage, test.regocache, test.cache, zap.NewNop())
res, err := svc.Evaluate(context.Background(), test.req) res, err := svc.Evaluate(context.Background(), test.req)
if err == nil { if err == nil {
assert.Empty(t, test.errtext) assert.Empty(t, test.errtext)
assert.Equal(t, test.res, res) assert.NotNil(t, res)
result, ok := res.(map[string]interface{})
assert.True(t, ok)
assert.Equal(t, test.res, result["result"])
assert.NotEmpty(t, result["evaluationID"])
} else { } else {
e, ok := err.(*errors.Error) e, ok := err.(*errors.Error)
assert.True(t, ok) assert.True(t, ok)
...@@ -243,7 +284,7 @@ func TestService_Lock(t *testing.T) { ...@@ -243,7 +284,7 @@ func TestService_Lock(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
svc := policy.New(test.storage, nil, zap.NewNop()) svc := policy.New(test.storage, nil, nil, zap.NewNop())
err := svc.Lock(context.Background(), test.req) err := svc.Lock(context.Background(), test.req)
if err == nil { if err == nil {
assert.Empty(t, test.errtext) assert.Empty(t, test.errtext)
...@@ -340,7 +381,7 @@ func TestService_Unlock(t *testing.T) { ...@@ -340,7 +381,7 @@ func TestService_Unlock(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
svc := policy.New(test.storage, nil, zap.NewNop()) svc := policy.New(test.storage, nil, nil, zap.NewNop())
err := svc.Unlock(context.Background(), test.req) err := svc.Unlock(context.Background(), test.req)
if err == nil { if err == nil {
assert.Empty(t, test.errtext) assert.Empty(t, test.errtext)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment