diff --git a/internal/regocache/regocache_test.go b/internal/regocache/regocache_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fd63e6f1b76f2ecf9e43bf596ebfaa0d8699fb87 --- /dev/null +++ b/internal/regocache/regocache_test.go @@ -0,0 +1,72 @@ +package regocache_test + +import ( + "context" + "testing" + + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/assert" + + "code.vereign.com/gaiax/tsa/policy/internal/regocache" + "code.vereign.com/gaiax/tsa/policy/internal/service/policy" +) + +const regoPolicy = ` + package test + + allow { + input.val == 1 + } +` + +func TestNew(t *testing.T) { + cache := regocache.New() + assert.Implements(t, (*policy.RegoCache)(nil), cache) +} + +func TestCache_SetAndGet(t *testing.T) { + q1, err := rego.New( + rego.Module("filename.rego", regoPolicy), + rego.Query("data"), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + cache := regocache.New() + cache.Set("query1", &q1) + + q2, ok := cache.Get("query1") + assert.True(t, ok) + assert.Equal(t, q1, *q2) +} + +func TestCache_Purge(t *testing.T) { + q1, err := rego.New( + rego.Module("filename.rego", regoPolicy), + rego.Query("data"), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + cache := regocache.New() + cache.Set("query1", &q1) + + cache.Purge() + q2, ok := cache.Get("query1") + assert.False(t, ok) + assert.Nil(t, q2) +} + +func TestCache_PolicyDataChange(t *testing.T) { + q1, err := rego.New( + rego.Module("filename.rego", regoPolicy), + rego.Query("data"), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + cache := regocache.New() + cache.Set("query1", &q1) + + cache.PolicyDataChange() + q2, ok := cache.Get("query1") + assert.False(t, ok) + assert.Nil(t, q2) +} diff --git a/internal/service/policy/policyfakes/fake_rego_cache.go b/internal/service/policy/policyfakes/fake_rego_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..4bc89b3dc24a07e9c32b4302f6cc2ea189665680 --- /dev/null +++ b/internal/service/policy/policyfakes/fake_rego_cache.go @@ -0,0 +1,158 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package policyfakes + +import ( + "sync" + + "code.vereign.com/gaiax/tsa/policy/internal/service/policy" + "github.com/open-policy-agent/opa/rego" +) + +type FakeRegoCache struct { + GetStub func(string) (*rego.PreparedEvalQuery, bool) + getMutex sync.RWMutex + getArgsForCall []struct { + arg1 string + } + getReturns struct { + result1 *rego.PreparedEvalQuery + result2 bool + } + getReturnsOnCall map[int]struct { + result1 *rego.PreparedEvalQuery + result2 bool + } + SetStub func(string, *rego.PreparedEvalQuery) + setMutex sync.RWMutex + setArgsForCall []struct { + arg1 string + arg2 *rego.PreparedEvalQuery + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRegoCache) Get(arg1 string) (*rego.PreparedEvalQuery, bool) { + fake.getMutex.Lock() + ret, specificReturn := fake.getReturnsOnCall[len(fake.getArgsForCall)] + fake.getArgsForCall = append(fake.getArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.GetStub + fakeReturns := fake.getReturns + fake.recordInvocation("Get", []interface{}{arg1}) + fake.getMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeRegoCache) GetCallCount() int { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + return len(fake.getArgsForCall) +} + +func (fake *FakeRegoCache) GetCalls(stub func(string) (*rego.PreparedEvalQuery, bool)) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = stub +} + +func (fake *FakeRegoCache) GetArgsForCall(i int) string { + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + argsForCall := fake.getArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRegoCache) GetReturns(result1 *rego.PreparedEvalQuery, result2 bool) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + fake.getReturns = struct { + result1 *rego.PreparedEvalQuery + result2 bool + }{result1, result2} +} + +func (fake *FakeRegoCache) GetReturnsOnCall(i int, result1 *rego.PreparedEvalQuery, result2 bool) { + fake.getMutex.Lock() + defer fake.getMutex.Unlock() + fake.GetStub = nil + if fake.getReturnsOnCall == nil { + fake.getReturnsOnCall = make(map[int]struct { + result1 *rego.PreparedEvalQuery + result2 bool + }) + } + fake.getReturnsOnCall[i] = struct { + result1 *rego.PreparedEvalQuery + result2 bool + }{result1, result2} +} + +func (fake *FakeRegoCache) Set(arg1 string, arg2 *rego.PreparedEvalQuery) { + fake.setMutex.Lock() + fake.setArgsForCall = append(fake.setArgsForCall, struct { + arg1 string + arg2 *rego.PreparedEvalQuery + }{arg1, arg2}) + stub := fake.SetStub + fake.recordInvocation("Set", []interface{}{arg1, arg2}) + fake.setMutex.Unlock() + if stub != nil { + fake.SetStub(arg1, arg2) + } +} + +func (fake *FakeRegoCache) SetCallCount() int { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + return len(fake.setArgsForCall) +} + +func (fake *FakeRegoCache) SetCalls(stub func(string, *rego.PreparedEvalQuery)) { + fake.setMutex.Lock() + defer fake.setMutex.Unlock() + fake.SetStub = stub +} + +func (fake *FakeRegoCache) SetArgsForCall(i int) (string, *rego.PreparedEvalQuery) { + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + argsForCall := fake.setArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRegoCache) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.getMutex.RLock() + defer fake.getMutex.RUnlock() + fake.setMutex.RLock() + defer fake.setMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRegoCache) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ policy.RegoCache = new(FakeRegoCache) diff --git a/internal/service/policy/policyfakes/fake_storage.go b/internal/service/policy/policyfakes/fake_storage.go new file mode 100644 index 0000000000000000000000000000000000000000..32bf19186b7c957b17f2c04cde9d745ee6e4edbd --- /dev/null +++ b/internal/service/policy/policyfakes/fake_storage.go @@ -0,0 +1,206 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package policyfakes + +import ( + "context" + "sync" + + "code.vereign.com/gaiax/tsa/policy/internal/service/policy" + "code.vereign.com/gaiax/tsa/policy/internal/storage" +) + +type FakeStorage struct { + PolicyStub func(context.Context, string, string, string) (*storage.Policy, error) + policyMutex sync.RWMutex + policyArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + } + policyReturns struct { + result1 *storage.Policy + result2 error + } + policyReturnsOnCall map[int]struct { + result1 *storage.Policy + result2 error + } + SetPolicyLockStub func(context.Context, string, string, string, bool) error + setPolicyLockMutex sync.RWMutex + setPolicyLockArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + arg5 bool + } + setPolicyLockReturns struct { + result1 error + } + setPolicyLockReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeStorage) Policy(arg1 context.Context, arg2 string, arg3 string, arg4 string) (*storage.Policy, error) { + fake.policyMutex.Lock() + ret, specificReturn := fake.policyReturnsOnCall[len(fake.policyArgsForCall)] + fake.policyArgsForCall = append(fake.policyArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + }{arg1, arg2, arg3, arg4}) + stub := fake.PolicyStub + fakeReturns := fake.policyReturns + fake.recordInvocation("Policy", []interface{}{arg1, arg2, arg3, arg4}) + fake.policyMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeStorage) PolicyCallCount() int { + fake.policyMutex.RLock() + defer fake.policyMutex.RUnlock() + return len(fake.policyArgsForCall) +} + +func (fake *FakeStorage) PolicyCalls(stub func(context.Context, string, string, string) (*storage.Policy, error)) { + fake.policyMutex.Lock() + defer fake.policyMutex.Unlock() + fake.PolicyStub = stub +} + +func (fake *FakeStorage) PolicyArgsForCall(i int) (context.Context, string, string, string) { + fake.policyMutex.RLock() + defer fake.policyMutex.RUnlock() + argsForCall := fake.policyArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeStorage) PolicyReturns(result1 *storage.Policy, result2 error) { + fake.policyMutex.Lock() + defer fake.policyMutex.Unlock() + fake.PolicyStub = nil + fake.policyReturns = struct { + result1 *storage.Policy + result2 error + }{result1, result2} +} + +func (fake *FakeStorage) PolicyReturnsOnCall(i int, result1 *storage.Policy, result2 error) { + fake.policyMutex.Lock() + defer fake.policyMutex.Unlock() + fake.PolicyStub = nil + if fake.policyReturnsOnCall == nil { + fake.policyReturnsOnCall = make(map[int]struct { + result1 *storage.Policy + result2 error + }) + } + fake.policyReturnsOnCall[i] = struct { + result1 *storage.Policy + result2 error + }{result1, result2} +} + +func (fake *FakeStorage) SetPolicyLock(arg1 context.Context, arg2 string, arg3 string, arg4 string, arg5 bool) error { + fake.setPolicyLockMutex.Lock() + ret, specificReturn := fake.setPolicyLockReturnsOnCall[len(fake.setPolicyLockArgsForCall)] + fake.setPolicyLockArgsForCall = append(fake.setPolicyLockArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 string + arg4 string + arg5 bool + }{arg1, arg2, arg3, arg4, arg5}) + stub := fake.SetPolicyLockStub + fakeReturns := fake.setPolicyLockReturns + fake.recordInvocation("SetPolicyLock", []interface{}{arg1, arg2, arg3, arg4, arg5}) + fake.setPolicyLockMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4, arg5) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorage) SetPolicyLockCallCount() int { + fake.setPolicyLockMutex.RLock() + defer fake.setPolicyLockMutex.RUnlock() + return len(fake.setPolicyLockArgsForCall) +} + +func (fake *FakeStorage) SetPolicyLockCalls(stub func(context.Context, string, string, string, bool) error) { + fake.setPolicyLockMutex.Lock() + defer fake.setPolicyLockMutex.Unlock() + fake.SetPolicyLockStub = stub +} + +func (fake *FakeStorage) SetPolicyLockArgsForCall(i int) (context.Context, string, string, string, bool) { + fake.setPolicyLockMutex.RLock() + defer fake.setPolicyLockMutex.RUnlock() + argsForCall := fake.setPolicyLockArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4, argsForCall.arg5 +} + +func (fake *FakeStorage) SetPolicyLockReturns(result1 error) { + fake.setPolicyLockMutex.Lock() + defer fake.setPolicyLockMutex.Unlock() + fake.SetPolicyLockStub = nil + fake.setPolicyLockReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) SetPolicyLockReturnsOnCall(i int, result1 error) { + fake.setPolicyLockMutex.Lock() + defer fake.setPolicyLockMutex.Unlock() + fake.SetPolicyLockStub = nil + if fake.setPolicyLockReturnsOnCall == nil { + fake.setPolicyLockReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setPolicyLockReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeStorage) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.policyMutex.RLock() + defer fake.policyMutex.RUnlock() + fake.setPolicyLockMutex.RLock() + defer fake.setPolicyLockMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeStorage) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ policy.Storage = new(FakeStorage) diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index 8c6d19ea2712a8e1af6edcb43fcbfb5de3d8999e..aff12a10b1ba8603a7d799e0cdc28fc317f63cae 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -12,6 +12,9 @@ import ( "code.vereign.com/gaiax/tsa/policy/internal/storage" ) +//go:generate counterfeiter . Storage +//go:generate counterfeiter . RegoCache + type Storage interface { Policy(ctx context.Context, name, group, version string) (*storage.Policy, error) SetPolicyLock(ctx context.Context, name, group, version string, lock bool) error diff --git a/internal/service/policy/service_test.go b/internal/service/policy/service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d1841b16c7cd6b554855dd8f076cdc866d7a39a --- /dev/null +++ b/internal/service/policy/service_test.go @@ -0,0 +1,162 @@ +package policy_test + +import ( + "context" + "testing" + "time" + + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + "code.vereign.com/gaiax/tsa/golib/errors" + goapolicy "code.vereign.com/gaiax/tsa/policy/gen/policy" + "code.vereign.com/gaiax/tsa/policy/internal/service/policy" + "code.vereign.com/gaiax/tsa/policy/internal/service/policy/policyfakes" + "code.vereign.com/gaiax/tsa/policy/internal/storage" +) + +func TestNew(t *testing.T) { + storage := &policyfakes.FakeStorage{} + regocache := &policyfakes.FakeRegoCache{} + svc := policy.New(storage, regocache, zap.NewNop()) + assert.Implements(t, (*goapolicy.Service)(nil), svc) +} + +func TestService_Evaluate(t *testing.T) { + // prepare test policy source code that will be evaluated + testPolicy := `package testgroup.example allow { input.msg == "yes" }` + + // prepare test query that can be retrieved from rego cache + testQuery, err := rego.New( + rego.Module("example.rego", testPolicy), + rego.Query("data.testgroup.example"), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + // prepare test request to be used in tests + testReq := func() *goapolicy.EvaluateRequest { + return &goapolicy.EvaluateRequest{ + Group: "testgroup", + PolicyName: "example", + Version: "1.0", + Input: map[string]interface{}{"msg": "yes"}, + } + } + + tests := []struct { + // test input + name string + req *goapolicy.EvaluateRequest + storage policy.Storage + regocache policy.RegoCache + + // expected result + res *goapolicy.EvaluateResult + errkind errors.Kind + errtext string + }{ + { + name: "prepared query is found in cache", + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + q := testQuery + return &q, true + }, + }, + res: &goapolicy.EvaluateResult{Result: map[string]interface{}{"allow": true}}, + }, + { + name: "policy is not found", + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return nil, errors.New(errors.NotFound) + }, + }, + res: nil, + errkind: errors.NotFound, + errtext: "not found", + }, + { + name: "error getting policy from storage", + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return nil, errors.New("some error") + }, + }, + res: nil, + errkind: errors.Unknown, + errtext: "some error", + }, + { + name: "policy is locked", + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return &storage.Policy{Locked: true}, nil + }, + }, + res: nil, + errkind: errors.Forbidden, + errtext: "policy is locked", + }, + { + name: "policy is found in storage and isn't locked", + req: testReq(), + regocache: &policyfakes.FakeRegoCache{ + GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { + return nil, false + }, + }, + storage: &policyfakes.FakeStorage{ + PolicyStub: func(ctx context.Context, s string, s2 string, s3 string) (*storage.Policy, error) { + return &storage.Policy{ + Name: "example", + Group: "testgroup", + Version: "1.0", + Rego: testPolicy, + Locked: false, + LastUpdate: time.Now(), + }, nil + }, + }, + res: &goapolicy.EvaluateResult{Result: map[string]interface{}{"allow": true}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + svc := policy.New(test.storage, test.regocache, zap.NewNop()) + res, err := svc.Evaluate(context.Background(), test.req) + if err == nil { + assert.Empty(t, test.errtext) + assert.Equal(t, test.res, res) + } else { + e, ok := err.(*errors.Error) + assert.True(t, ok) + + assert.Contains(t, e.Error(), test.errtext) + assert.Equal(t, test.errkind, e.Kind) + assert.Equal(t, test.res, res) + } + }) + } +}