diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 982e3625abf22dc5ffb06c6c0f05bd6c1a69193b..a2e7dd1a6c0d2b8d41bd24a57b5906ad6774e21a 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -85,6 +85,7 @@ func main() { didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient) taskFuncs := regofunc.NewTaskFuncs(cfg.Task.Addr, httpClient) keysFuncs := regofunc.NewPubkeyFuncs(cfg.Signer.Addr, httpClient) + ocmFuncs := regofunc.NewOcmFuncs(cfg.Ocm.Addr, httpClient) regofunc.Register("cacheGet", rego.Function3(cacheFuncs.CacheGetFunc())) regofunc.Register("cacheSet", rego.Function4(cacheFuncs.CacheSetFunc())) regofunc.Register("didResolve", rego.Function1(didResolverFuncs.ResolveFunc())) @@ -92,8 +93,9 @@ func main() { regofunc.Register("taskListCreate", rego.Function2(taskFuncs.CreateTaskListFunc())) regofunc.Register("getKey", rego.Function1(keysFuncs.GetKeyFunc())) regofunc.Register("getAllKeys", rego.FunctionDyn(keysFuncs.GetAllKeysFunc())) - regofunc.Register("getAllKeys", rego.FunctionDyn(keysFuncs.GetAllKeysFunc())) regofunc.Register("issuer", rego.FunctionDyn(keysFuncs.IssuerDID())) + regofunc.Register("ocmLoginProofInvitation", rego.Function1(ocmFuncs.GetLoginProofInvitation())) + regofunc.Register("ocmLoginProofResult", rego.Function1(ocmFuncs.GetLoginProofResult())) } // subscribe the cache for policy data changes diff --git a/internal/config/config.go b/internal/config/config.go index 1b89ed85ebc4b21141a8e598b394972ab7abb1df..79b94931a9c9a2c891e995cb33d30deb5d3f0fcb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,6 +10,7 @@ type Config struct { Signer signerConfig DIDResolver didResolverConfig Metrics metricsConfig + Ocm ocmConfig LogLevel string `envconfig:"LOG_LEVEL" default:"INFO"` } @@ -49,3 +50,7 @@ type mongoConfig struct { type metricsConfig struct { Addr string `envconfig:"METRICS_ADDR" default:":2112"` } + +type ocmConfig struct { + Addr string `envconfig:"OCM_ADDR" required:"true"` +} diff --git a/internal/regofunc/ocm.go b/internal/regofunc/ocm.go new file mode 100644 index 0000000000000000000000000000000000000000..7cc531e3b4fd2a0fbdc969fbd9cfb98d7a8560d0 --- /dev/null +++ b/internal/regofunc/ocm.go @@ -0,0 +1,85 @@ +package regofunc + +import ( + "fmt" + "net/http" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" + + "code.vereign.com/gaiax/tsa/golib/ocm" +) + +type OcmFuncs struct { + client *ocm.Client +} + +func NewOcmFuncs(ocmAddr string, httpClient *http.Client) *OcmFuncs { + ocmClient := ocm.New(ocmAddr, ocm.WithHTTPClient(httpClient)) + + return &OcmFuncs{client: ocmClient} +} + +func (of *OcmFuncs) GetLoginProofInvitation() (*rego.Function, rego.Builtin1) { + return ®o.Function{ + Name: "ocm.getLoginProofInvitation", + Decl: types.NewFunction(types.Args(types.A), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, types *ast.Term) (*ast.Term, error) { + var credTypes []string + + if err := ast.As(types.Value, &credTypes); err != nil { + return nil, fmt.Errorf("invalid credential types array: %s", err) + } + + res, err := of.client.GetLoginProofInvitation(bctx.Context, credTypes) + if err != nil { + return nil, err + } + + type result struct { + Link string `json:"link"` + RequestId string `json:"requestId"` + } + var val ast.Value + val, err = ast.InterfaceToValue(result{ + Link: res.Data.PresentationMessage, + RequestId: res.Data.PresentationID, + }) + if err != nil { + return nil, err + } + + return ast.NewTerm(val), nil + } +} + +func (of *OcmFuncs) GetLoginProofResult() (*rego.Function, rego.Builtin1) { + return ®o.Function{ + Name: "ocm.getLoginProofResult", + Decl: types.NewFunction(types.Args(types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, id *ast.Term) (*ast.Term, error) { + var presentationId string + + if err := ast.As(id.Value, &presentationId); err != nil { + return nil, fmt.Errorf("invalid presentationId: %s", err) + } + + res, err := of.client.GetLoginProofResult(bctx.Context, presentationId) + if err != nil { + return nil, err + } + + var val ast.Value + val, err = ast.InterfaceToValue(res.Data.Claims) + if err != nil { + return nil, err + } + + return ast.NewTerm(val), nil + } +} diff --git a/internal/regofunc/ocm_test.go b/internal/regofunc/ocm_test.go new file mode 100644 index 0000000000000000000000000000000000000000..63d022e54ab9102a2690357aa73be72285136cbc --- /dev/null +++ b/internal/regofunc/ocm_test.go @@ -0,0 +1,101 @@ +package regofunc_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/open-policy-agent/opa/rego" + "github.com/stretchr/testify/assert" + + "code.vereign.com/gaiax/tsa/policy/internal/regofunc" +) + +func TestGetLoginProofInvitationSuccess(t *testing.T) { + expected := `{"link":"https://ocm:443/ocm/didcomm/?d_m=eyJAdHlwZSI","requestId":"2cf01406-b15f-4960-a6a7-7bc62cd37a3c"}` + ocmResponse := `{ + "statusCode": 201, + "message": "Presentation request send successfully", + "data": { + "presentationId": "2cf01406-b15f-4960-a6a7-7bc62cd37a3c", + "presentationMessage": "https://ocm:443/ocm/didcomm/?d_m=eyJAdHlwZSI" + } + }` + + ocmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, ocmResponse) + })) + defer ocmSrv.Close() + + ocmFuncs := regofunc.NewOcmFuncs(ocmSrv.URL, http.DefaultClient) + + r := rego.New( + rego.Query(`ocm.getLoginProofInvitation(["openid", "profile"])`), + rego.Function1(ocmFuncs.GetLoginProofInvitation()), + rego.StrictBuiltinErrors(true), + ) + + resultSet, err := r.Eval(context.Background()) + assert.NoError(t, err) + + resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value) + assert.NoError(t, err) + assert.Equal(t, expected, string(resultBytes)) +} + +func TestGetLoginProofInvitationErr(t *testing.T) { + ocmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, `{"key":"value"}`) + })) + defer ocmSrv.Close() + + ocmFuncs := regofunc.NewOcmFuncs(ocmSrv.URL, http.DefaultClient) + + r := rego.New( + rego.Query(`ocm.getLoginProofInvitation("openid")`), + rego.Function1(ocmFuncs.GetLoginProofInvitation()), + rego.StrictBuiltinErrors(true), + ) + + resultSet, err := r.Eval(context.Background()) + assert.Error(t, err) + assert.Empty(t, resultSet) + assert.Contains(t, err.Error(), "cannot unmarshal string into Go value of type []string") +} + +func TestGetLoginProofResult(t *testing.T) { + expected := `{"family_name":"Doe","name":"John"}` + ocmResponse := `{ + "statusCode": 200, + "message": "Proof presentation fetch successfully", + "data": { + "credentialSubject": { + "name":"John", + "family_name":"Doe" + } + } + }` + + ocmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprint(w, ocmResponse) + })) + defer ocmSrv.Close() + + ocmFuncs := regofunc.NewOcmFuncs(ocmSrv.URL, http.DefaultClient) + + r := rego.New( + rego.Query(`ocm.getLoginProofResult("2cf01406-b15f-4960-a6a7-7bc62cd37a3c")`), + rego.Function1(ocmFuncs.GetLoginProofResult()), + rego.StrictBuiltinErrors(true), + ) + + resultSet, err := r.Eval(context.Background()) + assert.NoError(t, err) + + resultBytes, err := json.Marshal(resultSet[0].Expressions[0].Value) + assert.NoError(t, err) + assert.Equal(t, expected, string(resultBytes)) +}