Skip to content
Snippets Groups Projects
Commit e0453d30 authored by Valery Kalashnikov's avatar Valery Kalashnikov
Browse files

ref #17. Use factory approach to initialize the rego functions

parent 20142cd5
No related branches found
No related tags found
2 merge requests!12(kalashnikov) Rego cache functions,!11Rego functions for cache write and read (set and get)
Pipeline #50378 failed with stage
in 59 seconds
......@@ -9,6 +9,7 @@ import (
"time"
"github.com/kelseyhightower/envconfig"
"github.com/open-policy-agent/opa/rego"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
......@@ -71,12 +72,16 @@ func main() {
regocache := regocache.New()
// custom rego functions
regofuncs := regofunc.New(
regofuncCache := regofunc.NewCache(
cfg.Cache.Addr,
regofunc.WithHTTPClient(httpClient()),
regofunc.WithLogger(logger),
)
regofunc.Initialize("cacheGet", rego.Function3(regofuncCache.CacheGetFunc()))
regofunc.Initialize("cacheSet", rego.Function4(regofuncCache.CacheSetFunc()))
regofunc.Initialize("strictBuiltinErrors", rego.StrictBuiltinErrors(true))
// subscribe the cache for policy data changes
storage.AddPolicyChangeSubscriber(regocache)
......
package regofunc
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/types"
"go.uber.org/zap"
)
type CacheRegoFunc struct {
RegoFunc
cacheAddr string
}
func NewCache(cacheAddr string, opts ...Option) *CacheRegoFunc {
rf := &CacheRegoFunc{
RegoFunc: RegoFunc{
httpClient: http.DefaultClient,
logger: zap.NewNop(),
},
cacheAddr: cacheAddr,
}
// for _, opt := range opts {
// opt(rf)
// }
// return rf
return rf
}
func (r *CacheRegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) {
return &rego.Function{
Name: "cache.get",
Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A),
Memoize: true,
},
func(bctx rego.BuiltinContext, a, b, c *ast.Term) (*ast.Term, error) {
var key, namespace, scope string
if err := ast.As(a.Value, &key); err != nil {
return nil, fmt.Errorf("invalid key: %s", err)
} else if err = ast.As(b.Value, &namespace); err != nil {
return nil, fmt.Errorf("invalid namespace: %s", err)
} else if err = ast.As(c.Value, &scope); err != nil {
return nil, fmt.Errorf("invalid scope: %s", err)
}
req, err := http.NewRequest("GET", r.cacheAddr+"/v1/cache", nil)
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
if err != nil {
return nil, err
}
resp, err := r.httpClient.Do(req.WithContext(bctx.Context))
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound {
return nil, fmt.Errorf("unexpected response: %d %s", resp.StatusCode, resp.Status)
}
v, err := ast.ValueFromReader(resp.Body)
if err != nil {
return nil, err
}
return ast.NewTerm(v), nil
}
}
func (r *CacheRegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) {
return &rego.Function{
Name: "cache.set",
Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A),
Memoize: true,
},
func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) {
var key, namespace, scope string
var data map[string]interface{}
if err := ast.As(k.Value, &key); err != nil {
return nil, fmt.Errorf("invalid key: %s", err)
} else if err = ast.As(n.Value, &namespace); err != nil {
return nil, fmt.Errorf("invalid namespace: %s", err)
} else if err = ast.As(s.Value, &scope); err != nil {
return nil, fmt.Errorf("invalid scope: %s", err)
} else if err = ast.As(d.Value, &data); err != nil {
return nil, fmt.Errorf("invalid data: %s", err)
}
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", r.cacheAddr+"/v1/cache", bytes.NewReader(jsonData))
if err != nil {
return nil, err
}
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
resp, err := r.httpClient.Do(req.WithContext(bctx.Context))
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode)
}
var val ast.Value
val, err = ast.InterfaceToValue("success")
if err != nil {
return nil, err
}
return ast.NewTerm(val), nil
}
}
package regofunc
import (
"fmt"
"github.com/open-policy-agent/opa/rego"
)
type regoFuncFactory func(*rego.Rego)
var regoFuncFactories = make(map[string]regoFuncFactory)
func Initialize(name string, factory regoFuncFactory) {
if factory == nil {
panic(fmt.Errorf("datastore factory %s does not exist", name))
}
_, registered := regoFuncFactories[name]
if !registered {
regoFuncFactories[name] = factory
}
}
func FuncList() []regoFuncFactory {
list := make([]regoFuncFactory, 0)
for _, value := range regoFuncFactories {
list = append(list, value)
}
return list
}
package regofunc_test
import (
"fmt"
"testing"
"code.vereign.com/gaiax/tsa/policy/internal/regofunc"
"github.com/open-policy-agent/opa/rego"
)
func TestFactory_FuncList(t *testing.T) {
regofuncCache := regofunc.NewCache(
"localhost:8080",
)
regofunc.Initialize("cacheGet", rego.Function3(regofuncCache.CacheGetFunc()))
regofunc.Initialize("cacheSet", rego.Function3(regofuncCache.CacheGetFunc()))
go func() {
l := regofunc.FuncList()
fmt.Println(l)
}()
go func() {
l := regofunc.FuncList()
fmt.Println(l)
}()
}
......@@ -4,136 +4,12 @@
package regofunc
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/types"
"go.uber.org/zap"
)
type RegoFunc struct {
cacheAddr string
httpClient *http.Client
logger *zap.Logger
}
func New(cacheAddr string, opts ...Option) *RegoFunc {
rf := &RegoFunc{
cacheAddr: cacheAddr,
httpClient: http.DefaultClient,
logger: zap.NewNop(),
}
for _, opt := range opts {
opt(rf)
}
return rf
}
func (r *RegoFunc) CacheGetFunc() (*rego.Function, rego.Builtin3) {
return &rego.Function{
Name: "cache.get",
Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A),
Memoize: true,
},
func(bctx rego.BuiltinContext, a, b, c *ast.Term) (*ast.Term, error) {
var key, namespace, scope string
if err := ast.As(a.Value, &key); err != nil {
return nil, fmt.Errorf("invalid key: %s", err)
} else if err = ast.As(b.Value, &namespace); err != nil {
return nil, fmt.Errorf("invalid namespace: %s", err)
} else if err = ast.As(c.Value, &scope); err != nil {
return nil, fmt.Errorf("invalid scope: %s", err)
}
req, err := http.NewRequest("GET", r.cacheAddr+"/v1/cache", nil)
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
if err != nil {
return nil, err
}
resp, err := r.httpClient.Do(req.WithContext(bctx.Context))
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNotFound {
return nil, fmt.Errorf("unexpected response: %d %s", resp.StatusCode, resp.Status)
}
v, err := ast.ValueFromReader(resp.Body)
if err != nil {
return nil, err
}
return ast.NewTerm(v), nil
}
}
func (r *RegoFunc) CacheSetFunc() (*rego.Function, rego.Builtin4) {
return &rego.Function{
Name: "cache.set",
Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A),
Memoize: true,
},
func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) {
var key, namespace, scope string
var data map[string]interface{}
if err := ast.As(k.Value, &key); err != nil {
return nil, fmt.Errorf("invalid key: %s", err)
} else if err = ast.As(n.Value, &namespace); err != nil {
return nil, fmt.Errorf("invalid namespace: %s", err)
} else if err = ast.As(s.Value, &scope); err != nil {
return nil, fmt.Errorf("invalid scope: %s", err)
} else if err = ast.As(d.Value, &data); err != nil {
return nil, fmt.Errorf("invalid data: %s", err)
}
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", r.cacheAddr+"/v1/cache", bytes.NewReader(jsonData))
if err != nil {
return nil, err
}
req.Header = http.Header{
"x-cache-key": []string{key},
"x-cache-namespace": []string{namespace},
"x-cache-scope": []string{scope},
}
resp, err := r.httpClient.Do(req.WithContext(bctx.Context))
if err != nil {
return nil, err
}
defer resp.Body.Close() // nolint:errcheck
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode)
}
var val ast.Value
val, err = ast.InterfaceToValue("success")
if err != nil {
return nil, err
}
return ast.NewTerm(val), nil
}
}
package regofunc_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/assert"
"code.vereign.com/gaiax/tsa/policy/internal/regofunc"
)
func TestRegoFunc_CacheGetFunc(t *testing.T) {
expected := `{"taskID":"deadbeef"}`
cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprint(w, expected)
}))
defer cacheSrv.Close()
regofuncs := regofunc.New(cacheSrv.URL)
r := rego.New(
rego.Query(`cache.get("open-policy-agent", "opa", "111")`),
rego.Function3(regofuncs.CacheGetFunc()),
)
resultSet, err := r.Eval(context.Background())
assert.NoError(t, err)
resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value)
assert.NoError(t, err)
assert.Equal(t, expected, string(resultBytes))
}
func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) {
expected := "success"
cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedRequestBody := `{"test":123}`
bodyBytes, err := io.ReadAll(r.Body)
assert.NoError(t, err)
bodyString := string(bodyBytes)
if bodyString != expectedRequestBody {
assert.Equal(t, expectedRequestBody, bodyString)
}
w.WriteHeader(http.StatusCreated)
}))
defer cacheSrv.Close()
regofuncs := regofunc.New(cacheSrv.URL)
input := map[string]interface{}{"test": 123}
query, err := rego.New(
rego.Query(`cache.set("open-policy-agent", "opa", "111", input)`),
rego.Function4(regofuncs.CacheSetFunc()),
).PrepareForEval(context.Background())
assert.NoError(t, err)
resultSet, err := query.Eval(context.Background(), rego.EvalInput(input))
assert.NoError(t, err)
assert.NotEmpty(t, resultSet)
assert.NotEmpty(t, resultSet[0].Expressions)
assert.Equal(t, expected, resultSet[0].Expressions[0].Value)
}
func TestRegoFunc_CacheSetFuncError(t *testing.T) {
cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedRequestBody := "test"
bodyBytes, err := io.ReadAll(r.Body)
assert.NoError(t, err)
bodyString := string(bodyBytes)
assert.Equal(t, expectedRequestBody, bodyString)
w.WriteHeader(http.StatusNotFound)
}))
defer cacheSrv.Close()
regofuncs := regofunc.New(cacheSrv.URL)
r := rego.New(
rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`),
rego.Function4(regofuncs.CacheSetFunc()),
)
resultSet, err := r.Eval(context.Background())
assert.NoError(t, err)
assert.Empty(t, resultSet)
}
// import (
// "context"
// "encoding/json"
// "fmt"
// "io"
// "net/http"
// "net/http/httptest"
// "testing"
// "github.com/open-policy-agent/opa/rego"
// "github.com/stretchr/testify/assert"
// "code.vereign.com/gaiax/tsa/policy/internal/regofunc"
// )
// func TestRegoFunc_CacheGetFunc(t *testing.T) {
// expected := `{"taskID":"deadbeef"}`
// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// _, _ = fmt.Fprint(w, expected)
// }))
// defer cacheSrv.Close()
// regofuncs := regofunc.New(cacheSrv.URL)
// r := rego.New(
// rego.Query(`cache.get("open-policy-agent", "opa", "111")`),
// rego.Function3(regofuncs.CacheGetFunc()),
// )
// resultSet, err := r.Eval(context.Background())
// assert.NoError(t, err)
// resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value)
// assert.NoError(t, err)
// assert.Equal(t, expected, string(resultBytes))
// }
// func TestRegoFunc_CacheSetFuncSuccess(t *testing.T) {
// expected := "success"
// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// expectedRequestBody := `{"test":123}`
// bodyBytes, err := io.ReadAll(r.Body)
// assert.NoError(t, err)
// bodyString := string(bodyBytes)
// if bodyString != expectedRequestBody {
// assert.Equal(t, expectedRequestBody, bodyString)
// }
// w.WriteHeader(http.StatusCreated)
// }))
// defer cacheSrv.Close()
// regofuncs := regofunc.New(cacheSrv.URL)
// input := map[string]interface{}{"test": 123}
// query, err := rego.New(
// rego.Query(`cache.set("open-policy-agent", "opa", "111", input)`),
// rego.Function4(regofuncs.CacheSetFunc()),
// ).PrepareForEval(context.Background())
// assert.NoError(t, err)
// resultSet, err := query.Eval(context.Background(), rego.EvalInput(input))
// assert.NoError(t, err)
// assert.NotEmpty(t, resultSet)
// assert.NotEmpty(t, resultSet[0].Expressions)
// assert.Equal(t, expected, resultSet[0].Expressions[0].Value)
// }
// func TestRegoFunc_CacheSetFuncError(t *testing.T) {
// cacheSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// expectedRequestBody := "test"
// bodyBytes, err := io.ReadAll(r.Body)
// assert.NoError(t, err)
// bodyString := string(bodyBytes)
// assert.Equal(t, expectedRequestBody, bodyString)
// w.WriteHeader(http.StatusNotFound)
// }))
// defer cacheSrv.Close()
// regofuncs := regofunc.New(cacheSrv.URL)
// r := rego.New(
// rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`),
// rego.Function4(regofuncs.CacheSetFunc()),
// )
// resultSet, err := r.Eval(context.Background())
// assert.NoError(t, err)
// assert.Empty(t, resultSet)
// }
......@@ -173,11 +173,7 @@ func (s *Service) prepareQuery(ctx context.Context, policyName, group, version s
regoQuery := fmt.Sprintf("data.%s.%s", group, policyName)
newQuery, err := rego.New(
rego.Module(pol.Filename, pol.Rego),
rego.Query(regoQuery),
rego.Function3(s.regoFunc.CacheGetFunc()),
rego.Function4(s.regoFunc.CacheSetFunc()),
rego.StrictBuiltinErrors(true),
buildRegoArgs(pol.Filename, pol.Rego, regoQuery)...,
).PrepareForEval(ctx)
if err != nil {
return nil, errors.New("error preparing rego query", err)
......@@ -188,6 +184,17 @@ func (s *Service) prepareQuery(ctx context.Context, policyName, group, version s
return &newQuery, nil
}
func buildRegoArgs(filename, regoField, regoQuery string) (availableFuncs []func(*rego.Rego)) {
availableFuncs = make([]func(*rego.Rego), 0, 0)
availableFuncs[0] = rego.Module(filename, regoField)
availableFuncs[1] = rego.Query(regoQuery)
extensions := regofunc.FuncList()
for k := range availableFuncs {
availableFuncs = append(availableFuncs, extensions[k])
}
return
}
func (s *Service) queryCacheKey(policyName, group, version string) string {
return fmt.Sprintf("%s,%s,%s", policyName, group, version)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment