diff --git a/regoext/cache.go b/regoext/cache.go index 0e74f679c9558e08c2453b398d988a881ab9140a..60ea23ce6486f3e788e21a9bacd7eaeeacd39cbc 100644 --- a/regoext/cache.go +++ b/regoext/cache.go @@ -1,6 +1,7 @@ package regoext import ( + "bytes" "fmt" "net/http" @@ -9,6 +10,16 @@ import ( "github.com/open-policy-agent/opa/types" ) +const ( + Success string = "success" +) + +type CacheParams struct { + Key string + Namespace string + Scope string +} + type CacheExt struct { path string } @@ -64,3 +75,60 @@ func (ce *CacheExt) GetCacheFunc() (*rego.Function, rego.Builtin3) { return ast.NewTerm(v), nil } } + +func (ce *CacheExt) SetCacheFunc() (*rego.Function, rego.Builtin4) { + return ®o.Function{ + Name: "cache.set", + Decl: types.NewFunction(types.Args(types.S, types.S, types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, k, n, s, d *ast.Term) (*ast.Term, error) { + var key, namespace, scope, data string + + if err := ast.As(k.Value, &key); err != nil { + return nil, err + } else if err = ast.As(n.Value, &namespace); err != nil { + return nil, err + } else if err = ast.As(s.Value, &scope); err != nil { + return nil, err + } else if err = ast.As(d.Value, &data); err != nil { + return nil, err + } + + type Response struct { + Result string `json:"result"` + } + r := &Response{Success} + + payloadBuf := bytes.NewBufferString(data) + + req, err := http.NewRequest("POST", ce.path+"/v1/cache", payloadBuf) + req.Header = http.Header{ + "x-cache-key": []string{key}, + "x-cache-namespace": []string{namespace}, + "x-cache-scope": []string{scope}, + } + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return nil, err + } + + var val ast.Value + val, err = ast.InterfaceToValue(r) + if err != nil { + return nil, err + } + + return ast.NewTerm(val), nil + } +} diff --git a/regoext/cache_test.go b/regoext/cache_test.go index 5a897b866a2a4bd2d3002f7a8775777a10295496..b929377a34243db19bb06dc27c1f21fbc5b2bc3d 100644 --- a/regoext/cache_test.go +++ b/regoext/cache_test.go @@ -4,8 +4,11 @@ import ( "context" "encoding/json" "fmt" + "io" + "log" "net/http" "net/http/httptest" + "regexp" "testing" "code.vereign.com/gaiax/tsa/policy/regoext" @@ -43,3 +46,84 @@ func TestCacheExt_GetCacheFunc(t *testing.T) { t.Errorf("expected %s, got %s", expected, string(bs)) } } + +func TestCacheExt_SetCacheFuncSuccess(t *testing.T) { + expected := `{ "result": "success" }` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedRequestBody := "test" + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, "") + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.Fatal(err) + } + bodyString := string(bodyBytes) + if bodyString != expectedRequestBody { + t.Errorf("unexpected body string, expected %s, got %s", expectedRequestBody, bodyString) + } + })) + defer srv.Close() + + cache := regoext.NewCacheExt(srv.URL) + + r := rego.New( + rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), + rego.Function4(cache.SetCacheFunc()), + ) + + rs, err := r.Eval(context.Background()) + + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + bs, err := json.MarshalIndent(rs[0].Expressions[0].Value, "", " ") + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + re := regexp.MustCompile(`(\s+)|(\n)+`) + s := re.ReplaceAllString(string(bs), " ") + if s != expected { + t.Errorf("unexpected result, expected %s, got %s", expected, s) + } +} + +func TestCacheExt_SetCacheFuncError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedRequestBody := "test" + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "") + + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + log.Fatal(err) + } + bodyString := string(bodyBytes) + if bodyString != expectedRequestBody { + t.Errorf("unexpected body string, expected %s, got %s", expectedRequestBody, bodyString) + } + })) + defer srv.Close() + + cache := regoext.NewCacheExt(srv.URL) + + r := rego.New( + rego.Query(`cache.set("open-policy-agent", "opa", "111", "test")`), + rego.Function4(cache.SetCacheFunc()), + ) + + rs, err := r.Eval(context.Background()) + + if err != nil { + t.Errorf("unexpected error, %v", err) + return + } + + if len(rs) != 0 { + t.Errorf("result set should be empty, got %v", rs) + } +}