From 4b6bfc7abffba76229bd3fdfad991f0104b42f35 Mon Sep 17 00:00:00 2001
From: Yordan Kinkov <yordan.kinkov@vereign.com>
Date: Mon, 25 Jul 2022 14:17:17 +0300
Subject: [PATCH] Handle "scope to credential type" map in
 GetLoginProofInvitation func

---
 cmd/policy/main.go            |  2 +-
 internal/regofunc/ocm.go      | 24 ++++++++++++++++++------
 internal/regofunc/ocm_test.go | 23 ++++++++++++++++++-----
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/cmd/policy/main.go b/cmd/policy/main.go
index 9bdec2b0..9791cdc8 100644
--- a/cmd/policy/main.go
+++ b/cmd/policy/main.go
@@ -96,7 +96,7 @@ func main() {
 		regofunc.Register("issuer", rego.FunctionDyn(signerFuncs.IssuerDID()))
 		regofunc.Register("createProof", rego.Function1(signerFuncs.CreateProof()))
 		regofunc.Register("verifyProof", rego.Function1(signerFuncs.VerifyProof()))
-		regofunc.Register("ocmLoginProofInvitation", rego.Function1(ocmFuncs.GetLoginProofInvitation()))
+		regofunc.Register("ocmLoginProofInvitation", rego.Function2(ocmFuncs.GetLoginProofInvitation()))
 		regofunc.Register("ocmLoginProofResult", rego.Function1(ocmFuncs.GetLoginProofResult()))
 	}
 
diff --git a/internal/regofunc/ocm.go b/internal/regofunc/ocm.go
index b32f9922..dc2553b3 100644
--- a/internal/regofunc/ocm.go
+++ b/internal/regofunc/ocm.go
@@ -21,17 +21,29 @@ func NewOcmFuncs(ocmAddr string, httpClient *http.Client) *OcmFuncs {
 	return &OcmFuncs{client: ocmClient}
 }
 
-func (of *OcmFuncs) GetLoginProofInvitation() (*rego.Function, rego.Builtin1) {
+func (of *OcmFuncs) GetLoginProofInvitation() (*rego.Function, rego.Builtin2) {
 	return &rego.Function{
 			Name:    "ocm.getLoginProofInvitation",
-			Decl:    types.NewFunction(types.Args(types.A), types.A),
+			Decl:    types.NewFunction(types.Args(types.A, types.A), types.A),
 			Memoize: true,
 		},
-		func(bctx rego.BuiltinContext, types *ast.Term) (*ast.Term, error) {
-			var credTypes []string
+		func(bctx rego.BuiltinContext, rScopes *ast.Term, scopesMap *ast.Term) (*ast.Term, error) {
+			var scopes []string
+			var scopeToType map[string]string
+
+			if err := ast.As(rScopes.Value, &scopes); err != nil {
+				return nil, fmt.Errorf("invalid scopes array: %s", err)
+			} else if err = ast.As(scopesMap.Value, &scopeToType); err != nil {
+				return nil, fmt.Errorf("invalid scope to credential type map: %s", err)
+			}
 
-			if err := ast.As(types.Value, &credTypes); err != nil {
-				return nil, fmt.Errorf("invalid credential types array: %s", err)
+			var credTypes []string
+			for _, scope := range scopes {
+				credType, ok := scopeToType[scope]
+				if !ok {
+					return nil, fmt.Errorf("scope not found in scope to type map: %s", scope)
+				}
+				credTypes = append(credTypes, credType)
 			}
 
 			res, err := of.client.GetLoginProofInvitation(bctx.Context, credTypes)
diff --git a/internal/regofunc/ocm_test.go b/internal/regofunc/ocm_test.go
index 26de0b04..6707652c 100644
--- a/internal/regofunc/ocm_test.go
+++ b/internal/regofunc/ocm_test.go
@@ -33,8 +33,8 @@ func TestGetLoginProofInvitationSuccess(t *testing.T) {
 	ocmFuncs := regofunc.NewOcmFuncs(ocmSrv.URL, http.DefaultClient)
 
 	r := rego.New(
-		rego.Query(`ocm.getLoginProofInvitation(["openid", "profile"])`),
-		rego.Function1(ocmFuncs.GetLoginProofInvitation()),
+		rego.Query(`ocm.getLoginProofInvitation(["openid", "profile"], {"openid": "credType1", "profile": "credType2"})`),
+		rego.Function2(ocmFuncs.GetLoginProofInvitation()),
 		rego.StrictBuiltinErrors(true),
 	)
 
@@ -54,16 +54,29 @@ func TestGetLoginProofInvitationErr(t *testing.T) {
 
 	ocmFuncs := regofunc.NewOcmFuncs(ocmSrv.URL, http.DefaultClient)
 
+	// invalid scopes array
 	r := rego.New(
-		rego.Query(`ocm.getLoginProofInvitation("openid")`),
-		rego.Function1(ocmFuncs.GetLoginProofInvitation()),
+		rego.Query(`ocm.getLoginProofInvitation("openid", {"openid": "credType1", "profile": "credType2"})`),
+		rego.Function2(ocmFuncs.GetLoginProofInvitation()),
 		rego.StrictBuiltinErrors(true),
 	)
 
 	resultSet, err := r.Eval(context.Background())
 	assert.Error(t, err)
 	assert.Empty(t, resultSet)
-	assert.Contains(t, err.Error(), "cannot unmarshal string into Go value of type []string")
+	assert.Contains(t, err.Error(), "invalid scopes array")
+
+	// invalid "scope to credential type" map
+	r = rego.New(
+		rego.Query(`ocm.getLoginProofInvitation(["openid", "profile"], "map")`),
+		rego.Function2(ocmFuncs.GetLoginProofInvitation()),
+		rego.StrictBuiltinErrors(true),
+	)
+
+	resultSet, err = r.Eval(context.Background())
+	assert.Error(t, err)
+	assert.Empty(t, resultSet)
+	assert.Contains(t, err.Error(), "invalid scope to credential type map")
 }
 
 func TestGetLoginProofResult(t *testing.T) {
-- 
GitLab