diff --git a/regoext/cache.go b/regoext/cache.go new file mode 100644 index 0000000000000000000000000000000000000000..60ea23ce6486f3e788e21a9bacd7eaeeacd39cbc --- /dev/null +++ b/regoext/cache.go @@ -0,0 +1,134 @@ +package regoext + +import ( + "bytes" + "fmt" + "net/http" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" +) + +const ( + Success string = "success" +) + +type CacheParams struct { + Key string + Namespace string + Scope string +} + +type CacheExt struct { + path string +} + +func NewCacheExt(path string) *CacheExt { + return &CacheExt{path: path} +} + +func (ce *CacheExt) GetCacheFunc() (*rego.Function, rego.Builtin3) { + return ®o.Function{ + Name: "cache.get", + Decl: types.NewFunction(types.Args(types.S, types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, a, b, c *ast.Term) (*ast.Term, error) { + + var key, namespace, scope string + + if err := ast.As(a.Value, &key); err != nil { + return nil, err + } else if err = ast.As(b.Value, &namespace); err != nil { + return nil, err + } else if err = ast.As(c.Value, &scope); err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", ce.path+"/v1/cache", nil) + req.Header = http.Header{ + "x-cache-key": []string{key}, + "x-cache-namespace": []string{namespace}, + "x-cache-scope": []string{scope}, + } + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf(resp.Status) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (ce *CacheExt) SetCacheFunc() (*rego.Function, rego.Builtin4) { + return ®o.Function{ + Name: "cache.set", + Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) { + var key, namespace, scope, data string + + if err := ast.As(k.Value, &key); err != nil { + return nil, err + } else if err = ast.As(n.Value, &namespace); err != nil { + return nil, err + } else if err = ast.As(s.Value, &scope); err != nil { + return nil, err + } else if err = ast.As(d.Value, &data); err != nil { + return nil, err + } + + type Response struct { + Result string `json:"result"` + } + r := &Response{Success} + + payloadBuf := bytes.NewBufferString(data) + + req, err := http.NewRequest("POST", ce.path+"/v1/cache", payloadBuf) + req.Header = http.Header{ + "x-cache-key": []string{key}, + "x-cache-namespace": []string{namespace}, + "x-cache-scope": []string{scope}, + } + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return nil, err + } + + var val ast.Value + val, err = ast.InterfaceToValue(r) + if err != nil { + return nil, err + } + + return ast.NewTerm(val), nil + } +} diff --git a/regoext/cache_test.go b/regoext/cache_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b929377a34243db19bb06dc27c1f21fbc5b2bc3d --- /dev/null +++ b/regoext/cache_test.go @@ -0,0 +1,129 @@ +package regoext_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "regexp" + "testing" + + "code.vereign.com/gaiax/tsa/policy/regoext" + "github.com/open-policy-agent/opa/rego" +) + +func TestCacheExt_GetCacheFunc(t *testing.T) { + expected := "{}" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, expected) + })) + defer srv.Close() + + cache := regoext.NewCacheExt(srv.URL) + + r := rego.New( + rego.Query(`cache.get("open-policy-agent", "opa", "111")`), + rego.Function3(cache.GetCacheFunc()), + ) + + rs, err := r.Eval(context.Background()) + + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + bs, err := json.MarshalIndent(rs[0].Expressions[0].Value, "", " ") + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + if string(bs) != expected { + t.Errorf("expected %s, got %s", expected, string(bs)) + } +} + +func TestCacheExt_SetCacheFuncSuccess(t *testing.T) { + expected := `{ "result": "success" }` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedRequestBody := "test" + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, "") + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.Fatal(err) + } + bodyString := string(bodyBytes) + if bodyString != expectedRequestBody { + t.Errorf("unexpected body string, expected %s, got %s", expectedRequestBody, bodyString) + } + })) + defer srv.Close() + + cache := regoext.NewCacheExt(srv.URL) + + r := rego.New( + rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), + rego.Function4(cache.SetCacheFunc()), + ) + + rs, err := r.Eval(context.Background()) + + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + bs, err := json.MarshalIndent(rs[0].Expressions[0].Value, "", " ") + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + re := regexp.MustCompile(`(\s+)|(\n)+`) + s := re.ReplaceAllString(string(bs), " ") + if s != expected { + t.Errorf("unexpected result, expected %s, got %s", expected, s) + } +} + +func TestCacheExt_SetCacheFuncError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedRequestBody := "test" + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "") + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.Fatal(err) + } + bodyString := string(bodyBytes) + if bodyString != expectedRequestBody { + t.Errorf("unexpected body string, expected %s, got %s", expectedRequestBody, bodyString) + } + })) + defer srv.Close() + + cache := regoext.NewCacheExt(srv.URL) + + r := rego.New( + rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), + rego.Function4(cache.SetCacheFunc()), + ) + + rs, err := r.Eval(context.Background()) + + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + if len(rs) != 0 { + t.Errorf("result set should be empty, got %v", rs) + } +}