diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 89ac306184072b2b50784b073c038a37fc9bfac4..ab7ecd306b4df19086d3619a4c867f8eb4118a5d 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kelseyhightower/envconfig" + "github.com/open-policy-agent/opa/rego" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.uber.org/zap" @@ -70,12 +71,13 @@ func main() { // create rego query cache regocache := regocache.New() - // custom rego functions - regofuncs := regofunc.New( - cfg.Cache.Addr, - regofunc.WithHTTPClient(httpClient()), - regofunc.WithLogger(logger), - ) + // register rego extension functions + { + cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient(), logger) + regofunc.Register("cacheGet", rego.Function3(cacheFuncs.CacheGetFunc())) + regofunc.Register("cacheSet", rego.Function4(cacheFuncs.CacheSetFunc())) + regofunc.Register("strictBuiltinErrors", rego.StrictBuiltinErrors(true)) + } // subscribe the cache for policy data changes storage.AddPolicyChangeSubscriber(regocache) @@ -86,7 +88,7 @@ func main() { healthSvc goahealth.Service ) { - policySvc = policy.New(storage, regocache, regofuncs, logger) + policySvc = policy.New(storage, regocache, logger) healthSvc = health.New() } diff --git a/internal/regofunc/regofunc.go b/internal/regofunc/cache.go similarity index 77% rename from internal/regofunc/regofunc.go rename to internal/regofunc/cache.go index c7977c4169b57aa8dd0e0aa1380d85ed6d13b6e0..8f0b2821f5fe8dc4849aa1d579bcd3382b2c94a5 100644 --- a/internal/regofunc/regofunc.go +++ b/internal/regofunc/cache.go @@ -1,6 +1,3 @@ -// Package regofunc provides functions that extend the Rego runtime -// with additional capabilities and built-in functions which can be -// used when writing and evaluating Rego polices. package regofunc import ( @@ -15,28 +12,21 @@ import ( "go.uber.org/zap" ) -type RegoFunc struct { - cacheAddr string - +type CacheFuncs struct { + cacheAddr string httpClient *http.Client logger *zap.Logger } -func New(cacheAddr string, opts ...Option) *RegoFunc { - rf := &RegoFunc{ +func NewCacheFuncs(cacheAddr string, httpClient *http.Client, logger *zap.Logger) *CacheFuncs { + return &CacheFuncs{ cacheAddr: cacheAddr, - httpClient: http.DefaultClient, - logger: zap.NewNop(), - } - - for _, opt := range opts { - opt(rf) + httpClient: httpClient, + logger: logger, } - - return rf } -func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { +func (cf *CacheFuncs) CacheGetFunc() (*rego.Function, rego.Builtin3) { return ®o.Function{ Name: "cache.get", Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A), @@ -53,7 +43,7 @@ func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { return nil, fmt.Errorf("invalid scope: %s", err) } - req, err := http.NewRequest("GET", r.cacheAddr+"/v1/cache", nil) + req, err := http.NewRequest("GET", cf.cacheAddr+"/v1/cache", nil) req.Header = http.Header{ "x-cache-key": []string{key}, "x-cache-namespace": []string{namespace}, @@ -63,7 +53,7 @@ func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { return nil, err } - resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) + resp, err := cf.httpClient.Do(req.WithContext(bctx.Context)) if err != nil { return nil, err } @@ -82,7 +72,7 @@ func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { } } -func (r *RegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) { +func (cf *CacheFuncs) CacheSetFunc() (*rego.Function, rego.Builtin4) { return ®o.Function{ Name: "cache.set", Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A), @@ -107,7 +97,7 @@ func (r *RegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) { return nil, err } - req, err := http.NewRequest("POST", r.cacheAddr+"/v1/cache", bytes.NewReader(jsonData)) + req, err := http.NewRequest("POST", cf.cacheAddr+"/v1/cache", bytes.NewReader(jsonData)) if err != nil { return nil, err } @@ -118,7 +108,7 @@ func (r *RegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) { "x-cache-scope": []string{scope}, } - resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) + resp, err := cf.httpClient.Do(req.WithContext(bctx.Context)) if err != nil { return nil, err } diff --git a/internal/regofunc/regofunc_test.go b/internal/regofunc/cache_test.go similarity index 80% rename from internal/regofunc/regofunc_test.go rename to internal/regofunc/cache_test.go index 86e48dc96c900a401bb449c573604ada9446f066..150f5c4cd21136f14aa6f696556fbda9c899ed9a 100644 --- a/internal/regofunc/regofunc_test.go +++ b/internal/regofunc/cache_test.go @@ -11,22 +11,23 @@ import ( "github.com/open-policy-agent/opa/rego" "github.com/stretchr/testify/assert" + "go.uber.org/zap" "code.vereign.com/gaiax/tsa/policy/internal/regofunc" ) -func TestRegoFunc_CacheGetFunc(t *testing.T) { +func TestCacheGetFunc(t *testing.T) { expected := `{"taskID":"deadbeef"}` cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprint(w, expected) })) defer cacheSrv.Close() - regofuncs := regofunc.New(cacheSrv.URL) + cacheFuncs := regofunc.NewCacheFuncs(cacheSrv.URL, http.DefaultClient, zap.NewNop()) r := rego.New( rego.Query(`cache.get("open-policy-agent", "opa", "111")`), - rego.Function3(regofuncs.CacheGetFunc()), + rego.Function3(cacheFuncs.CacheGetFunc()), ) resultSet, err := r.Eval(context.Background()) assert.NoError(t, err) @@ -36,7 +37,7 @@ func TestRegoFunc_CacheGetFunc(t *testing.T) { assert.Equal(t, expected, string(resultBytes)) } -func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) { +func TestCacheSetFuncSuccess(t *testing.T) { expected := "success" cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expectedRequestBody := `{"test":123}` @@ -52,12 +53,12 @@ func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) { })) defer cacheSrv.Close() - regofuncs := regofunc.New(cacheSrv.URL) + cacheFuncs := regofunc.NewCacheFuncs(cacheSrv.URL, http.DefaultClient, zap.NewNop()) input := map[string]interface{}{"test": 123} query, err := rego.New( rego.Query(`cache.set("open-policy-agent", "opa", "111", input)`), - rego.Function4(regofuncs.CacheSetFunc()), + rego.Function4(cacheFuncs.CacheSetFunc()), ).PrepareForEval(context.Background()) assert.NoError(t, err) @@ -68,7 +69,7 @@ func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) { assert.Equal(t, expected, resultSet[0].Expressions[0].Value) } -func TestRegoFunc_CacheSetFuncError(t *testing.T) { +func TestCacheSetFuncError(t *testing.T) { cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { expectedRequestBody := "test" bodyBytes, err := io.ReadAll(r.Body) @@ -81,11 +82,11 @@ func TestRegoFunc_CacheSetFuncError(t *testing.T) { })) defer cacheSrv.Close() - regofuncs := regofunc.New(cacheSrv.URL) + cacheFuncs := regofunc.NewCacheFuncs(cacheSrv.URL, http.DefaultClient, zap.NewNop()) r := rego.New( rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), - rego.Function4(regofuncs.CacheSetFunc()), + rego.Function4(cacheFuncs.CacheSetFunc()), ) resultSet, err := r.Eval(context.Background()) diff --git a/internal/regofunc/doc.go b/internal/regofunc/doc.go new file mode 100644 index 0000000000000000000000000000000000000000..47bd85d7653a9d303a5eadf58b7337a35b1322d3 --- /dev/null +++ b/internal/regofunc/doc.go @@ -0,0 +1,4 @@ +// Package regofunc provides functions that extend the Rego runtime +// with additional capabilities and built-in functions which can be +// used when writing and evaluating Rego polices. +package regofunc diff --git a/internal/regofunc/option.go b/internal/regofunc/option.go deleted file mode 100644 index 0e1e41cd66cbc89cb5481abc37217a65b76cd3f5..0000000000000000000000000000000000000000 --- a/internal/regofunc/option.go +++ /dev/null @@ -1,21 +0,0 @@ -package regofunc - -import ( - "net/http" - - "go.uber.org/zap" -) - -type Option func(*RegoFunc) - -func WithHTTPClient(client *http.Client) Option { - return func(r *RegoFunc) { - r.httpClient = client - } -} - -func WithLogger(logger *zap.Logger) Option { - return func(c *RegoFunc) { - c.logger = logger - } -} diff --git a/internal/regofunc/registry.go b/internal/regofunc/registry.go new file mode 100644 index 0000000000000000000000000000000000000000..7ff98dd12873a0d759a2f4be3482598a64c4a02f --- /dev/null +++ b/internal/regofunc/registry.go @@ -0,0 +1,37 @@ +package regofunc + +import ( + "fmt" + "sync" + + "github.com/open-policy-agent/opa/rego" +) + +type RegoFunc func(*rego.Rego) + +var ( + muRegistry sync.RWMutex + regoFuncRegistry = make(map[string]RegoFunc) +) + +// Register an extension function. +func Register(name string, fn RegoFunc) { + if fn == nil { + panic(fmt.Errorf("cannot register nil Rego function: %s", name)) + } + + if _, registered := regoFuncRegistry[name]; !registered { + regoFuncRegistry[name] = fn + } +} + +// List returns all registered extension functions. +func List() []RegoFunc { + list := make([]RegoFunc, 0) + muRegistry.RLock() + for _, fn := range regoFuncRegistry { + list = append(list, fn) + } + muRegistry.RUnlock() + return list +} diff --git a/internal/regofunc/registry_test.go b/internal/regofunc/registry_test.go new file mode 100644 index 0000000000000000000000000000000000000000..dad462175a3b35c376ce5ac05203c2f1129e1114 --- /dev/null +++ b/internal/regofunc/registry_test.go @@ -0,0 +1,24 @@ +package regofunc_test + +import ( + "net/http" + "testing" + + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + "code.vereign.com/gaiax/tsa/policy/internal/regofunc" +) + +func TestList(t *testing.T) { + funcs := regofunc.List() + assert.Len(t, funcs, 0) + + cacheFuncs := regofunc.NewCacheFuncs("localhost:8080", http.DefaultClient, zap.NewNop()) + regofunc.Register("cacheGet", rego.Function3(cacheFuncs.CacheGetFunc())) + regofunc.Register("cacheSet", rego.Function3(cacheFuncs.CacheGetFunc())) + + funcs = regofunc.List() + assert.Len(t, funcs, 2) +} diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index 9a1bfc7c215798c05ee5753109e7da890c9b0d1a..89bec03d4b1dacb6bdb8b614a800f0a6483bbbed 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -27,18 +27,16 @@ type RegoCache interface { } type Service struct { - storage Storage - cache RegoCache - regoFunc *regofunc.RegoFunc - logger *zap.Logger + storage Storage + queryCache RegoCache + logger *zap.Logger } -func New(storage Storage, cache RegoCache, regoFunc *regofunc.RegoFunc, logger *zap.Logger) *Service { +func New(storage Storage, queryCache RegoCache, logger *zap.Logger) *Service { return &Service{ - storage: storage, - cache: cache, - regoFunc: regoFunc, - logger: logger, + storage: storage, + queryCache: queryCache, + logger: logger, } } @@ -145,11 +143,11 @@ func (s *Service) Unlock(ctx context.Context, req *policy.UnlockRequest) error { } // prepareQuery tries to get a prepared query from the regocache. -// If the cache entry is not found, it will try to prepare a new -// query and will set it into the cache for future use. +// If the queryCache entry is not found, it will try to prepare a new +// query and will set it into the queryCache for future use. func (s *Service) prepareQuery(ctx context.Context, policyName, group, version string) (*rego.PreparedEvalQuery, error) { key := s.queryCacheKey(policyName, group, version) - query, ok := s.cache.Get(key) + query, ok := s.queryCache.Get(key) if ok { return query, nil } @@ -173,21 +171,28 @@ func (s *Service) prepareQuery(ctx context.Context, policyName, group, version s regoQuery := fmt.Sprintf("data.%s.%s", group, policyName) newQuery, err := rego.New( - rego.Module(pol.Filename, pol.Rego), - rego.Query(regoQuery), - rego.Function3(s.regoFunc.CacheGetFunc()), - rego.Function4(s.regoFunc.CacheSetFunc()), - rego.StrictBuiltinErrors(true), + buildRegoArgs(pol.Filename, pol.Rego, regoQuery)..., ).PrepareForEval(ctx) if err != nil { return nil, errors.New("error preparing rego query", err) } - s.cache.Set(key, &newQuery) + s.queryCache.Set(key, &newQuery) return &newQuery, nil } +func buildRegoArgs(filename, regoPolicy, regoQuery string) (availableFuncs []func(*rego.Rego)) { + availableFuncs = make([]func(*rego.Rego), 2) + availableFuncs[0] = rego.Module(filename, regoPolicy) + availableFuncs[1] = rego.Query(regoQuery) + extensions := regofunc.List() + for i := range extensions { + availableFuncs = append(availableFuncs, extensions[i]) + } + return +} + func (s *Service) queryCacheKey(policyName, group, version string) string { return fmt.Sprintf("%s,%s,%s", policyName, group, version) } diff --git a/internal/service/policy/service_test.go b/internal/service/policy/service_test.go index 95fead05f1e6d5a30dc3d678e960173c489ef3fa..5b6949a50f091c42308aa88fd89e69cd755f3008 100644 --- a/internal/service/policy/service_test.go +++ b/internal/service/policy/service_test.go @@ -11,7 +11,6 @@ import ( "code.vereign.com/gaiax/tsa/golib/errors" goapolicy "code.vereign.com/gaiax/tsa/policy/gen/policy" - "code.vereign.com/gaiax/tsa/policy/internal/regofunc" "code.vereign.com/gaiax/tsa/policy/internal/service/policy" "code.vereign.com/gaiax/tsa/policy/internal/service/policy/policyfakes" "code.vereign.com/gaiax/tsa/policy/internal/storage" @@ -20,8 +19,7 @@ import ( func TestNew(t *testing.T) { storage := &policyfakes.FakeStorage{} regocache := &policyfakes.FakeRegoCache{} - regofuncs := regofunc.New("https://example.com") - svc := policy.New(storage, regocache, regofuncs, zap.NewNop()) + svc := policy.New(storage, regocache, zap.NewNop()) assert.Implements(t, (*goapolicy.Service)(nil), svc) } @@ -29,7 +27,7 @@ func TestService_Evaluate(t *testing.T) { // prepare test policy source code that will be evaluated testPolicy := `package testgroup.example allow { input.msg == "yes" }` - // prepare test query that can be retrieved from rego cache + // prepare test query that can be retrieved from rego queryCache testQuery, err := rego.New( rego.Module("example.rego", testPolicy), rego.Query("data.testgroup.example"), @@ -59,7 +57,7 @@ func TestService_Evaluate(t *testing.T) { errtext string }{ { - name: "prepared query is found in cache", + name: "prepared query is found in queryCache", req: testReq(), regocache: &policyfakes.FakeRegoCache{ GetStub: func(key string) (*rego.PreparedEvalQuery, bool) { @@ -146,8 +144,7 @@ func TestService_Evaluate(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - regofuncs := regofunc.New("https://example.com") - svc := policy.New(test.storage, test.regocache, regofuncs, zap.NewNop()) + svc := policy.New(test.storage, test.regocache, zap.NewNop()) res, err := svc.Evaluate(context.Background(), test.req) if err == nil { assert.Empty(t, test.errtext) @@ -246,8 +243,7 @@ func TestService_Lock(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - regofuncs := regofunc.New("https://example.com") - svc := policy.New(test.storage, nil, regofuncs, zap.NewNop()) + svc := policy.New(test.storage, nil, zap.NewNop()) err := svc.Lock(context.Background(), test.req) if err == nil { assert.Empty(t, test.errtext) @@ -344,8 +340,7 @@ func TestService_Unlock(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - regofuncs := regofunc.New("https://example.com") - svc := policy.New(test.storage, nil, regofuncs, zap.NewNop()) + svc := policy.New(test.storage, nil, zap.NewNop()) err := svc.Unlock(context.Background(), test.req) if err == nil { assert.Empty(t, test.errtext)