diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 1eec5678d6967bcb8ff3652f59ed47e16b02db50..26690bd44e4d29ef04a8378ca7024c1c1bac75a6 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -86,6 +86,7 @@ func main() { regofunc.Register("cacheSet", rego.Function4(cacheFuncs.CacheSetFunc())) regofunc.Register("didResolve", rego.Function1(didResolverFuncs.ResolveFunc())) regofunc.Register("taskCreate", rego.Function2(taskFuncs.CreateTaskFunc())) + regofunc.Register("taskListCreate", rego.Function2(taskFuncs.CreateTaskListFunc())) regofunc.Register("strictBuiltinErrors", rego.StrictBuiltinErrors(true)) } diff --git a/internal/regofunc/task.go b/internal/regofunc/task.go index 0ea857c360c58b5b765072880b900752642bea9f..d23565588f75cd0f398e8b29cbbb1909a34a572b 100644 --- a/internal/regofunc/task.go +++ b/internal/regofunc/task.go @@ -75,3 +75,55 @@ func (t *TaskFuncs) CreateTaskFunc() (*rego.Function, rego.Builtin2) { return ast.NewTerm(v), nil } } + +// CreateTaskListFunc returns a rego function for creating task lists. +func (t *TaskFuncs) CreateTaskListFunc() (*rego.Function, rego.Builtin2) { + return ®o.Function{ + Name: "tasklist.create", + Decl: types.NewFunction(types.Args(types.S, types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, taskListName, taskListData *ast.Term) (*ast.Term, error) { + var name string + var data map[string]interface{} + + if err := ast.As(taskListName.Value, &name); err != nil { + return nil, fmt.Errorf("invalid taskList name: %s", err) + } else if err = ast.As(taskListData.Value, &data); err != nil { + return nil, fmt.Errorf("invalid data: %s", err) + } + + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + fullURL := fmt.Sprintf("%s/v1/taskList/%s", t.taskAddr, name) + u, err := url.ParseRequestURI(fullURL) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + resp, err := t.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response code: %d", resp.StatusCode) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} diff --git a/internal/regofunc/task_test.go b/internal/regofunc/task_test.go index de92eaabd10ca2fe74af9f064176d7ec2f5e7552..d74a5c2a46b62397eed90caf153194eafb9bdb9f 100644 --- a/internal/regofunc/task_test.go +++ b/internal/regofunc/task_test.go @@ -85,3 +85,76 @@ func TestTaskFuncs_CreateTask(t *testing.T) { } } } + +func TestTaskFuncs_CreateTaskList(t *testing.T) { + tests := []struct { + name string + input map[string]interface{} + taskHandler func(w http.ResponseWriter, r *http.Request) + + response map[string]interface{} + errtext string + }{ + { + name: "taskList not found", + input: map[string]interface{}{"test": 123}, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"taskList not found"}`)) + }, + errtext: "tasklist.create: unexpected response code: 404", + }, + { + name: "task service returns error", + input: map[string]interface{}{"test": 123}, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }, + errtext: "tasklist.create: unexpected response code: 500", + }, + { + name: "task service returns invalid JSON response", + input: nil, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("boom")) + }, + response: nil, + errtext: "tasklist.create: invalid character", + }, + { + name: "taskList is created successfully", + input: nil, + taskHandler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"taskListID":"hello"}`)) + }, + response: map[string]interface{}{"taskListID": "hello"}, + errtext: "", + }, + } + + for _, test := range tests { + srv := httptest.NewServer(http.HandlerFunc(test.taskHandler)) + taskFuncs := regofunc.NewTaskFuncs(srv.URL, http.DefaultClient) + + query, err := rego.New( + rego.Query(`tasklist.create("taskListName", input)`), + rego.Function2(taskFuncs.CreateTaskListFunc()), + rego.StrictBuiltinErrors(true), + ).PrepareForEval(context.Background()) + assert.NoError(t, err) + + resultSet, err := query.Eval(context.Background(), rego.EvalInput(test.input)) + if test.errtext != "" { + assert.Nil(t, resultSet) + assert.Error(t, err) + assert.Contains(t, err.Error(), test.errtext) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, resultSet) + assert.NotEmpty(t, resultSet[0].Expressions) + assert.Equal(t, test.response, resultSet[0].Expressions[0].Value) + } + } +}