diff --git a/cmd/policy/main.go b/cmd/policy/main.go index d85ff216f92623938fbc94a6ad0790a883f703c7..100261acdf3a11b0c346ead3548e52daecf52993 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -66,18 +66,24 @@ func main() { defer db.Disconnect(context.Background()) //nolint:errcheck // create storage - storage := storage.New(db, cfg.Mongo.DB, cfg.Mongo.Collection, logger) + storage, err := storage.New(db, cfg.Mongo.DB, cfg.Mongo.Collection, logger) + if err != nil { + logger.Fatal("error connecting to database", zap.Error(err)) + } // create rego query cache regocache := regocache.New() // register rego extension functions { - cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient()) - didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient()) + httpClient := httpClient() + cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient) + didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient) + taskFuncs := regofunc.NewTaskFuncs(cfg.Task.Addr, httpClient) regofunc.Register("cacheGet", rego.Function3(cacheFuncs.CacheGetFunc())) regofunc.Register("cacheSet", rego.Function4(cacheFuncs.CacheSetFunc())) - regofunc.Register("didResolve", rego.Function1(didResolverFuncs.Resolve())) + regofunc.Register("didResolve", rego.Function1(didResolverFuncs.ResolveFunc())) + regofunc.Register("taskCreate", rego.Function2(taskFuncs.CreateTaskFunc())) regofunc.Register("strictBuiltinErrors", rego.StrictBuiltinErrors(true)) } diff --git a/internal/config/config.go b/internal/config/config.go index 9a52c10a1bdbbe1e7a668326fb5e77cea00015d0..84573dbb869aeacffc13b34af7dcd6f82bfb49e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ type Config struct { HTTP httpConfig Mongo mongoConfig Cache cacheConfig + Task taskConfig DIDResolver didResolverConfig LogLevel string `envconfig:"LOG_LEVEL" default:"INFO"` @@ -23,6 +24,10 @@ type cacheConfig struct { Addr string `envconfig:"CACHE_ADDR" required:"true"` } +type taskConfig struct { + Addr string `envconfig:"TASK_ADDR" required:"true"` +} + type didResolverConfig struct { Addr string `envconfig:"DID_RESOLVER_ADDR" required:"true"` } diff --git a/internal/regofunc/did_resolver.go b/internal/regofunc/did_resolver.go index 45a7d2645b25d0ff2359a9a3dea0aa5162be1c2b..a4b0bf48581b244a04cec007cc4eba4eb7022b5a 100644 --- a/internal/regofunc/did_resolver.go +++ b/internal/regofunc/did_resolver.go @@ -23,7 +23,7 @@ func NewDIDResolverFuncs(resolverAddr string, httpClient *http.Client) *DIDResol } } -func (dr *DIDResolverFuncs) Resolve() (*rego.Function, rego.Builtin1) { +func (dr *DIDResolverFuncs) ResolveFunc() (*rego.Function, rego.Builtin1) { return ®o.Function{ Name: "did.resolve", Decl: types.NewFunction(types.Args(types.S), types.A), diff --git a/internal/regofunc/did_resolver_test.go b/internal/regofunc/did_resolver_test.go index cbe3156f8fb3326c6d37ecdb090080b1bdae7bdc..24155722f1e6fc8f4e8e16038c0e2134675af215 100644 --- a/internal/regofunc/did_resolver_test.go +++ b/internal/regofunc/did_resolver_test.go @@ -25,7 +25,7 @@ func TestResolveFunc(t *testing.T) { r := rego.New( rego.Query(`did.resolve("did:indy:idunion:BDrEcHc8Tb4Lb2VyQZWEDE")`), - rego.Function1(DIDResolverFuncs.Resolve()), + rego.Function1(DIDResolverFuncs.ResolveFunc()), ) resultSet, err := r.Eval(context.Background()) assert.NoError(t, err) diff --git a/internal/regofunc/task.go b/internal/regofunc/task.go new file mode 100644 index 0000000000000000000000000000000000000000..0ea857c360c58b5b765072880b900752642bea9f --- /dev/null +++ b/internal/regofunc/task.go @@ -0,0 +1,77 @@ +package regofunc + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" +) + +type TaskFuncs struct { + taskAddr string + httpClient *http.Client +} + +func NewTaskFuncs(taskAddr string, httpClient *http.Client) *TaskFuncs { + return &TaskFuncs{ + taskAddr: taskAddr, + httpClient: httpClient, + } +} + +// CreateTaskFunc returns a rego function for creating tasks. +func (t *TaskFuncs) CreateTaskFunc() (*rego.Function, rego.Builtin2) { + return ®o.Function{ + Name: "task.create", + Decl: types.NewFunction(types.Args(types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, taskName, taskData *ast.Term) (*ast.Term, error) { + var name string + var data map[string]interface{} + + if err := ast.As(taskName.Value, &name); err != nil { + return nil, fmt.Errorf("invalid task name: %s", err) + } else if err = ast.As(taskData.Value, &data); err != nil { + return nil, fmt.Errorf("invalid data: %s", err) + } + + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + fullURL := fmt.Sprintf("%s/v1/task/%s", t.taskAddr, name) + u, err := url.ParseRequestURI(fullURL) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + resp, err := t.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} diff --git a/internal/regofunc/task_test.go b/internal/regofunc/task_test.go new file mode 100644 index 0000000000000000000000000000000000000000..de92eaabd10ca2fe74af9f064176d7ec2f5e7552 --- /dev/null +++ b/internal/regofunc/task_test.go @@ -0,0 +1,87 @@ +package regofunc_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/assert" + + "code.vereign.com/gaiax/tsa/policy/internal/regofunc" +) + +func TestTaskFuncs_CreateTask(t *testing.T) { + tests := []struct { + name string + taskName interface{} + input map[string]interface{} + taskHandler func(w http.ResponseWriter, r *http.Request) + + response map[string]interface{} + errtext string + }{ + { + name: "task not found", + input: map[string]interface{}{"test": 123}, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"task not found"}`)) + }, + errtext: "task.create: unexpected response code: 404", + }, + { + name: "task service returns error", + input: map[string]interface{}{"test": 123}, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }, + errtext: "task.create: unexpected response code: 500", + }, + { + name: "task service returns invalid JSON response", + input: nil, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("boom")) + }, + response: nil, + errtext: "task.create: invalid character", + }, + { + name: "task is created successfully", + input: nil, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"taskID":"hello"}`)) + }, + response: map[string]interface{}{"taskID": "hello"}, + errtext: "", + }, + } + + for _, test := range tests { + srv := httptest.NewServer(http.HandlerFunc(test.taskHandler)) + taskFuncs := regofunc.NewTaskFuncs(srv.URL, http.DefaultClient) + + query, err := rego.New( + rego.Query(`task.create("taskName", input)`), + rego.Function2(taskFuncs.CreateTaskFunc()), + rego.StrictBuiltinErrors(true), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + resultSet, err := query.Eval(context.Background(), rego.EvalInput(test.input)) + if test.errtext != "" { + assert.Nil(t, resultSet) + assert.Error(t, err) + assert.Contains(t, err.Error(), test.errtext) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, resultSet) + assert.NotEmpty(t, resultSet[0].Expressions) + assert.Equal(t, test.response, resultSet[0].Expressions[0].Value) + } + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index e87e75c6da4e98911540643b4a80fc8664041cf3..97302794587644fe8d443ffde0265a0b7fec9881 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -32,11 +32,15 @@ type Storage struct { logger *zap.Logger } -func New(db *mongo.Client, dbname, collection string, logger *zap.Logger) *Storage { +func New(db *mongo.Client, dbname, collection string, logger *zap.Logger) (*Storage, error) { + if err := db.Ping(context.Background(), nil); err != nil { + return nil, err + } + return &Storage{ policy: db.Database(dbname).Collection(collection), logger: logger, - } + }, nil } func (s *Storage) Policy(ctx context.Context, group, name, version string) (*Policy, error) {