diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 89ac306184072b2b50784b073c038a37fc9bfac4..3daf1a13a74871effa5ad0436d51e5493fdc9d2c 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/kelseyhightower/envconfig" + "github.com/open-policy-agent/opa/rego" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.uber.org/zap" @@ -71,12 +72,16 @@ func main() { regocache := regocache.New() // custom rego functions - regofuncs := regofunc.New( + regofuncCache := regofunc.NewCache( cfg.Cache.Addr, regofunc.WithHTTPClient(httpClient()), regofunc.WithLogger(logger), ) + regofunc.Initialize("cacheGet", rego.Function3(regofuncCache.CacheGetFunc())) + regofunc.Initialize("cacheSet", rego.Function4(regofuncCache.CacheSetFunc())) + regofunc.Initialize("strictBuiltinErrors", rego.StrictBuiltinErrors(true)) + // subscribe the cache for policy data changes storage.AddPolicyChangeSubscriber(regocache) diff --git a/internal/regofunc/cache.go b/internal/regofunc/cache.go new file mode 100644 index 0000000000000000000000000000000000000000..2561321a022899efac54b58d9c5870b3268629c3 --- /dev/null +++ b/internal/regofunc/cache.go @@ -0,0 +1,138 @@ +package regofunc + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" + "go.uber.org/zap" +) + +type CacheRegoFunc struct { + RegoFunc + + cacheAddr string +} + +func NewCache(cacheAddr string, opts ...Option) *CacheRegoFunc { + rf := &CacheRegoFunc{ + RegoFunc: RegoFunc{ + httpClient: http.DefaultClient, + logger: zap.NewNop(), + }, + cacheAddr: cacheAddr, + } + + // for _, opt := range opts { + // opt(rf) + // } + + // return rf + return rf +} + +func (r *CacheRegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { + return ®o.Function{ + Name: "cache.get", + Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, a, b, c *ast.Term) (*ast.Term, error) { + var key, namespace, scope string + + if err := ast.As(a.Value, &key); err != nil { + return nil, fmt.Errorf("invalid key: %s", err) + } else if err = ast.As(b.Value, &namespace); err != nil { + return nil, fmt.Errorf("invalid namespace: %s", err) + } else if err = ast.As(c.Value, &scope); err != nil { + return nil, fmt.Errorf("invalid scope: %s", err) + } + + req, err := http.NewRequest("GET", r.cacheAddr+"/v1/cache", nil) + req.Header = http.Header{ + "x-cache-key": []string{key}, + "x-cache-namespace": []string{namespace}, + "x-cache-scope": []string{scope}, + } + if err != nil { + return nil, err + } + + resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound { + return nil, fmt.Errorf("unexpected response: %d %s", resp.StatusCode, resp.Status) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (r *CacheRegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) { + return ®o.Function{ + Name: "cache.set", + Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) { + var key, namespace, scope string + var data map[string]interface{} + + if err := ast.As(k.Value, &key); err != nil { + return nil, fmt.Errorf("invalid key: %s", err) + } else if err = ast.As(n.Value, &namespace); err != nil { + return nil, fmt.Errorf("invalid namespace: %s", err) + } else if err = ast.As(s.Value, &scope); err != nil { + return nil, fmt.Errorf("invalid scope: %s", err) + } else if err = ast.As(d.Value, &data); err != nil { + return nil, fmt.Errorf("invalid data: %s", err) + } + + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", r.cacheAddr+"/v1/cache", bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + req.Header = http.Header{ + "x-cache-key": []string{key}, + "x-cache-namespace": []string{namespace}, + "x-cache-scope": []string{scope}, + } + + resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode) + } + + var val ast.Value + val, err = ast.InterfaceToValue("success") + if err != nil { + return nil, err + } + + return ast.NewTerm(val), nil + } +} diff --git a/internal/regofunc/factory.go b/internal/regofunc/factory.go new file mode 100644 index 0000000000000000000000000000000000000000..e5fbc42511767e00036467f94ee9873669c44001 --- /dev/null +++ b/internal/regofunc/factory.go @@ -0,0 +1,32 @@ +package regofunc + +import ( + "fmt" + + "github.com/open-policy-agent/opa/rego" +) + +type regoFuncFactory func(*rego.Rego) + +var regoFuncFactories = make(map[string]regoFuncFactory) + +func Initialize(name string, factory regoFuncFactory) { + if factory == nil { + panic(fmt.Errorf("datastore factory %s does not exist", name)) + } + + _, registered := regoFuncFactories[name] + if !registered { + regoFuncFactories[name] = factory + } +} + +func FuncList() []regoFuncFactory { + + list := make([]regoFuncFactory, 0) + + for _, value := range regoFuncFactories { + list = append(list, value) + } + return list +} diff --git a/internal/regofunc/factory_test.go b/internal/regofunc/factory_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6c382e8cb6574cc2aed401bbf8513809fb6378a9 --- /dev/null +++ b/internal/regofunc/factory_test.go @@ -0,0 +1,25 @@ +package regofunc_test + +import ( + "fmt" + "testing" + + "code.vereign.com/gaiax/tsa/policy/internal/regofunc" + "github.com/open-policy-agent/opa/rego" +) + +func TestFactory_FuncList(t *testing.T) { + regofuncCache := regofunc.NewCache( + "localhost:8080", + ) + regofunc.Initialize("cacheGet", rego.Function3(regofuncCache.CacheGetFunc())) + regofunc.Initialize("cacheSet", rego.Function3(regofuncCache.CacheGetFunc())) + go func() { + l := regofunc.FuncList() + fmt.Println(l) + }() + go func() { + l := regofunc.FuncList() + fmt.Println(l) + }() +} diff --git a/internal/regofunc/regofunc.go b/internal/regofunc/regofunc.go index c7977c4169b57aa8dd0e0aa1380d85ed6d13b6e0..e54ae8016d83a01d95f6e12c77fe1dae26d74319 100644 --- a/internal/regofunc/regofunc.go +++ b/internal/regofunc/regofunc.go @@ -4,136 +4,12 @@ package regofunc import ( - "bytes" - "encoding/json" - "fmt" "net/http" - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/types" "go.uber.org/zap" ) type RegoFunc struct { - cacheAddr string - httpClient *http.Client logger *zap.Logger } - -func New(cacheAddr string, opts ...Option) *RegoFunc { - rf := &RegoFunc{ - cacheAddr: cacheAddr, - httpClient: http.DefaultClient, - logger: zap.NewNop(), - } - - for _, opt := range opts { - opt(rf) - } - - return rf -} - -func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) { - return ®o.Function{ - Name: "cache.get", - Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A), - Memoize: true, - }, - func(bctx rego.BuiltinContext, a, b, c *ast.Term) (*ast.Term, error) { - var key, namespace, scope string - - if err := ast.As(a.Value, &key); err != nil { - return nil, fmt.Errorf("invalid key: %s", err) - } else if err = ast.As(b.Value, &namespace); err != nil { - return nil, fmt.Errorf("invalid namespace: %s", err) - } else if err = ast.As(c.Value, &scope); err != nil { - return nil, fmt.Errorf("invalid scope: %s", err) - } - - req, err := http.NewRequest("GET", r.cacheAddr+"/v1/cache", nil) - req.Header = http.Header{ - "x-cache-key": []string{key}, - "x-cache-namespace": []string{namespace}, - "x-cache-scope": []string{scope}, - } - if err != nil { - return nil, err - } - - resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) - if err != nil { - return nil, err - } - defer resp.Body.Close() // nolint:errcheck - - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound { - return nil, fmt.Errorf("unexpected response: %d %s", resp.StatusCode, resp.Status) - } - - v, err := ast.ValueFromReader(resp.Body) - if err != nil { - return nil, err - } - - return ast.NewTerm(v), nil - } -} - -func (r *RegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) { - return ®o.Function{ - Name: "cache.set", - Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A), - Memoize: true, - }, - func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) { - var key, namespace, scope string - var data map[string]interface{} - - if err := ast.As(k.Value, &key); err != nil { - return nil, fmt.Errorf("invalid key: %s", err) - } else if err = ast.As(n.Value, &namespace); err != nil { - return nil, fmt.Errorf("invalid namespace: %s", err) - } else if err = ast.As(s.Value, &scope); err != nil { - return nil, fmt.Errorf("invalid scope: %s", err) - } else if err = ast.As(d.Value, &data); err != nil { - return nil, fmt.Errorf("invalid data: %s", err) - } - - jsonData, err := json.Marshal(data) - if err != nil { - return nil, err - } - - req, err := http.NewRequest("POST", r.cacheAddr+"/v1/cache", bytes.NewReader(jsonData)) - if err != nil { - return nil, err - } - - req.Header = http.Header{ - "x-cache-key": []string{key}, - "x-cache-namespace": []string{namespace}, - "x-cache-scope": []string{scope}, - } - - resp, err := r.httpClient.Do(req.WithContext(bctx.Context)) - if err != nil { - return nil, err - } - defer resp.Body.Close() // nolint:errcheck - - if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode) - } - - var val ast.Value - val, err = ast.InterfaceToValue("success") - if err != nil { - return nil, err - } - - return ast.NewTerm(val), nil - } -} diff --git a/internal/regofunc/regofunc_test.go b/internal/regofunc/regofunc_test.go index 86e48dc96c900a401bb449c573604ada9446f066..b0852f483ccd772005241c6db62aa5079ba00cec 100644 --- a/internal/regofunc/regofunc_test.go +++ b/internal/regofunc/regofunc_test.go @@ -1,94 +1,94 @@ package regofunc_test -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/open-policy-agent/opa/rego" - "github.com/stretchr/testify/assert" - - "code.vereign.com/gaiax/tsa/policy/internal/regofunc" -) - -func TestRegoFunc_CacheGetFunc(t *testing.T) { - expected := `{"taskID":"deadbeef"}` - cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = fmt.Fprint(w, expected) - })) - defer cacheSrv.Close() - - regofuncs := regofunc.New(cacheSrv.URL) - - r := rego.New( - rego.Query(`cache.get("open-policy-agent", "opa", "111")`), - rego.Function3(regofuncs.CacheGetFunc()), - ) - resultSet, err := r.Eval(context.Background()) - assert.NoError(t, err) - - resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value) - assert.NoError(t, err) - assert.Equal(t, expected, string(resultBytes)) -} - -func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) { - expected := "success" - cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - expectedRequestBody := `{"test":123}` - bodyBytes, err := io.ReadAll(r.Body) - assert.NoError(t, err) - - bodyString := string(bodyBytes) - if bodyString != expectedRequestBody { - assert.Equal(t, expectedRequestBody, bodyString) - } - - w.WriteHeader(http.StatusCreated) - })) - defer cacheSrv.Close() - - regofuncs := regofunc.New(cacheSrv.URL) - - input := map[string]interface{}{"test": 123} - query, err := rego.New( - rego.Query(`cache.set("open-policy-agent", "opa", "111", input)`), - rego.Function4(regofuncs.CacheSetFunc()), - ).PrepareForEval(context.Background()) - assert.NoError(t, err) - - resultSet, err := query.Eval(context.Background(), rego.EvalInput(input)) - assert.NoError(t, err) - assert.NotEmpty(t, resultSet) - assert.NotEmpty(t, resultSet[0].Expressions) - assert.Equal(t, expected, resultSet[0].Expressions[0].Value) -} - -func TestRegoFunc_CacheSetFuncError(t *testing.T) { - cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - expectedRequestBody := "test" - bodyBytes, err := io.ReadAll(r.Body) - assert.NoError(t, err) - - bodyString := string(bodyBytes) - assert.Equal(t, expectedRequestBody, bodyString) - - w.WriteHeader(http.StatusNotFound) - })) - defer cacheSrv.Close() - - regofuncs := regofunc.New(cacheSrv.URL) - - r := rego.New( - rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), - rego.Function4(regofuncs.CacheSetFunc()), - ) - - resultSet, err := r.Eval(context.Background()) - assert.NoError(t, err) - assert.Empty(t, resultSet) -} +// import ( +// "context" +// "encoding/json" +// "fmt" +// "io" +// "net/http" +// "net/http/httptest" +// "testing" + +// "github.com/open-policy-agent/opa/rego" +// "github.com/stretchr/testify/assert" + +// "code.vereign.com/gaiax/tsa/policy/internal/regofunc" +// ) + +// func TestRegoFunc_CacheGetFunc(t *testing.T) { +// expected := `{"taskID":"deadbeef"}` +// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// _, _ = fmt.Fprint(w, expected) +// })) +// defer cacheSrv.Close() + +// regofuncs := regofunc.New(cacheSrv.URL) + +// r := rego.New( +// rego.Query(`cache.get("open-policy-agent", "opa", "111")`), +// rego.Function3(regofuncs.CacheGetFunc()), +// ) +// resultSet, err := r.Eval(context.Background()) +// assert.NoError(t, err) + +// resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value) +// assert.NoError(t, err) +// assert.Equal(t, expected, string(resultBytes)) +// } + +// func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) { +// expected := "success" +// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// expectedRequestBody := `{"test":123}` +// bodyBytes, err := io.ReadAll(r.Body) +// assert.NoError(t, err) + +// bodyString := string(bodyBytes) +// if bodyString != expectedRequestBody { +// assert.Equal(t, expectedRequestBody, bodyString) +// } + +// w.WriteHeader(http.StatusCreated) +// })) +// defer cacheSrv.Close() + +// regofuncs := regofunc.New(cacheSrv.URL) + +// input := map[string]interface{}{"test": 123} +// query, err := rego.New( +// rego.Query(`cache.set("open-policy-agent", "opa", "111", input)`), +// rego.Function4(regofuncs.CacheSetFunc()), +// ).PrepareForEval(context.Background()) +// assert.NoError(t, err) + +// resultSet, err := query.Eval(context.Background(), rego.EvalInput(input)) +// assert.NoError(t, err) +// assert.NotEmpty(t, resultSet) +// assert.NotEmpty(t, resultSet[0].Expressions) +// assert.Equal(t, expected, resultSet[0].Expressions[0].Value) +// } + +// func TestRegoFunc_CacheSetFuncError(t *testing.T) { +// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// expectedRequestBody := "test" +// bodyBytes, err := io.ReadAll(r.Body) +// assert.NoError(t, err) + +// bodyString := string(bodyBytes) +// assert.Equal(t, expectedRequestBody, bodyString) + +// w.WriteHeader(http.StatusNotFound) +// })) +// defer cacheSrv.Close() + +// regofuncs := regofunc.New(cacheSrv.URL) + +// r := rego.New( +// rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), +// rego.Function4(regofuncs.CacheSetFunc()), +// ) + +// resultSet, err := r.Eval(context.Background()) +// assert.NoError(t, err) +// assert.Empty(t, resultSet) +// } diff --git a/internal/service/policy/service.go b/internal/service/policy/service.go index 9a1bfc7c215798c05ee5753109e7da890c9b0d1a..d42698358c24e55dcb537a4e7ebfacf8a5dfdd1a 100644 --- a/internal/service/policy/service.go +++ b/internal/service/policy/service.go @@ -173,11 +173,7 @@ func (s *Service) prepareQuery(ctx context.Context, policyName, group, version s regoQuery := fmt.Sprintf("data.%s.%s", group, policyName) newQuery, err := rego.New( - rego.Module(pol.Filename, pol.Rego), - rego.Query(regoQuery), - rego.Function3(s.regoFunc.CacheGetFunc()), - rego.Function4(s.regoFunc.CacheSetFunc()), - rego.StrictBuiltinErrors(true), + buildRegoArgs(pol.Filename, pol.Rego, regoQuery)..., ).PrepareForEval(ctx) if err != nil { return nil, errors.New("error preparing rego query", err) @@ -188,6 +184,17 @@ func (s *Service) prepareQuery(ctx context.Context, policyName, group, version s return &newQuery, nil } +func buildRegoArgs(filename, regoField, regoQuery string) (availableFuncs []func(*rego.Rego)) { + availableFuncs = make([]func(*rego.Rego), 0, 0) + availableFuncs[0] = rego.Module(filename, regoField) + availableFuncs[1] = rego.Query(regoQuery) + extensions := regofunc.FuncList() + for k := range availableFuncs { + availableFuncs = append(availableFuncs, extensions[k]) + } + return +} + func (s *Service) queryCacheKey(policyName, group, version string) string { return fmt.Sprintf("%s,%s,%s", policyName, group, version) }