diff --git a/cmd/policy/main.go b/cmd/policy/main.go index 982e3625abf22dc5ffb06c6c0f05bd6c1a69193b..3fd6a720ff8ea2691b8d01cc3aaa04e7bb18e77e 100644 --- a/cmd/policy/main.go +++ b/cmd/policy/main.go @@ -84,16 +84,17 @@ func main() { cacheFuncs := regofunc.NewCacheFuncs(cfg.Cache.Addr, httpClient) didResolverFuncs := regofunc.NewDIDResolverFuncs(cfg.DIDResolver.Addr, httpClient) taskFuncs := regofunc.NewTaskFuncs(cfg.Task.Addr, httpClient) - keysFuncs := regofunc.NewPubkeyFuncs(cfg.Signer.Addr, httpClient) + signerFuncs := regofunc.NewSignerFuncs(cfg.Signer.Addr, httpClient) regofunc.Register("cacheGet", rego.Function3(cacheFuncs.CacheGetFunc())) regofunc.Register("cacheSet", rego.Function4(cacheFuncs.CacheSetFunc())) regofunc.Register("didResolve", rego.Function1(didResolverFuncs.ResolveFunc())) regofunc.Register("taskCreate", rego.Function2(taskFuncs.CreateTaskFunc())) regofunc.Register("taskListCreate", rego.Function2(taskFuncs.CreateTaskListFunc())) - regofunc.Register("getKey", rego.Function1(keysFuncs.GetKeyFunc())) - regofunc.Register("getAllKeys", rego.FunctionDyn(keysFuncs.GetAllKeysFunc())) - regofunc.Register("getAllKeys", rego.FunctionDyn(keysFuncs.GetAllKeysFunc())) - regofunc.Register("issuer", rego.FunctionDyn(keysFuncs.IssuerDID())) + regofunc.Register("getKey", rego.Function1(signerFuncs.GetKeyFunc())) + regofunc.Register("getAllKeys", rego.FunctionDyn(signerFuncs.GetAllKeysFunc())) + regofunc.Register("issuer", rego.FunctionDyn(signerFuncs.IssuerDID())) + regofunc.Register("createProof", rego.Function1(signerFuncs.CreateProof())) + regofunc.Register("verifyProof", rego.Function1(signerFuncs.VerifyProof())) } // subscribe the cache for policy data changes diff --git a/internal/regofunc/pubkeys.go b/internal/regofunc/pubkeys.go deleted file mode 100644 index c9873c6b9c246df6205b2d0d781a60a8de5e1617..0000000000000000000000000000000000000000 --- a/internal/regofunc/pubkeys.go +++ /dev/null @@ -1,141 +0,0 @@ -package regofunc - -import ( - "fmt" - "net/http" - "net/url" - "strings" - - "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" - "github.com/open-policy-agent/opa/types" -) - -type PubkeyFuncs struct { - signerAddr string - httpClient *http.Client -} - -func NewPubkeyFuncs(signerAddr string, httpClient *http.Client) *PubkeyFuncs { - return &PubkeyFuncs{ - signerAddr: signerAddr, - httpClient: httpClient, - } -} - -func (pf *PubkeyFuncs) GetKeyFunc() (*rego.Function, rego.Builtin1) { - return ®o.Function{ - Name: "keys.get", - Decl: types.NewFunction(types.Args(types.S), types.A), - Memoize: true, - }, - func(bctx rego.BuiltinContext, keyname *ast.Term) (*ast.Term, error) { - var key string - if err := ast.As(keyname.Value, &key); err != nil { - return nil, fmt.Errorf("invalid keyname: %s", err) - } - - if strings.TrimSpace(key) == "" { - return nil, fmt.Errorf("empty keyname") - } - - uri, err := url.ParseRequestURI(pf.signerAddr + "/v1/keys/" + key) - if err != nil { - return nil, err - } - - req, err := http.NewRequest("GET", uri.String(), nil) - if err != nil { - return nil, err - } - - resp, err := pf.httpClient.Do(req.WithContext(bctx.Context)) - if err != nil { - return nil, err - } - defer resp.Body.Close() // nolint:errcheck - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) - } - - v, err := ast.ValueFromReader(resp.Body) - if err != nil { - return nil, err - } - - return ast.NewTerm(v), nil - } -} - -func (pf *PubkeyFuncs) GetAllKeysFunc() (*rego.Function, rego.BuiltinDyn) { - return ®o.Function{ - Name: "keys.getAll", - Decl: types.NewFunction(nil, types.A), - Memoize: true, - }, - func(bctx rego.BuiltinContext, terms []*ast.Term) (*ast.Term, error) { - uri, err := url.ParseRequestURI(pf.signerAddr + "/v1/keys") - if err != nil { - return nil, err - } - - req, err := http.NewRequest("GET", uri.String(), nil) - if err != nil { - return nil, err - } - - resp, err := pf.httpClient.Do(req.WithContext(bctx.Context)) - if err != nil { - return nil, err - } - defer resp.Body.Close() // nolint:errcheck - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) - } - - v, err := ast.ValueFromReader(resp.Body) - if err != nil { - return nil, err - } - - return ast.NewTerm(v), nil - } -} - -func (pf *PubkeyFuncs) IssuerDID() (*rego.Function, rego.BuiltinDyn) { - return ®o.Function{ - Name: "issuer", - Decl: types.NewFunction(nil, types.A), - Memoize: true, - }, - func(bctx rego.BuiltinContext, terms []*ast.Term) (*ast.Term, error) { - uri, err := url.ParseRequestURI(pf.signerAddr + "/v1/issuerDID") - if err != nil { - return nil, err - } - - req, err := http.NewRequest("GET", uri.String(), nil) - if err != nil { - return nil, err - } - - resp, err := pf.httpClient.Do(req.WithContext(bctx.Context)) - if err != nil { - return nil, err - } - defer resp.Body.Close() // nolint:errcheck - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) - } - - v, err := ast.ValueFromReader(resp.Body) - if err != nil { - return nil, err - } - - return ast.NewTerm(v), nil - } -} diff --git a/internal/regofunc/signer.go b/internal/regofunc/signer.go new file mode 100644 index 0000000000000000000000000000000000000000..8eedb65851fb2377c9fd6ca698860ddb71e4761b --- /dev/null +++ b/internal/regofunc/signer.go @@ -0,0 +1,275 @@ +package regofunc + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" +) + +type SignerFuncs struct { + signerAddr string + httpClient *http.Client +} + +func NewSignerFuncs(signerAddr string, httpClient *http.Client) *SignerFuncs { + return &SignerFuncs{ + signerAddr: signerAddr, + httpClient: httpClient, + } +} + +func (sf *SignerFuncs) GetKeyFunc() (*rego.Function, rego.Builtin1) { + return ®o.Function{ + Name: "keys.get", + Decl: types.NewFunction(types.Args(types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, keyname *ast.Term) (*ast.Term, error) { + var key string + if err := ast.As(keyname.Value, &key); err != nil { + return nil, fmt.Errorf("invalid keyname: %s", err) + } + + if strings.TrimSpace(key) == "" { + return nil, fmt.Errorf("empty keyname") + } + + uri, err := url.ParseRequestURI(sf.signerAddr + "/v1/keys/" + key) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", uri.String(), nil) + if err != nil { + return nil, err + } + + resp, err := sf.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (sf *SignerFuncs) GetAllKeysFunc() (*rego.Function, rego.BuiltinDyn) { + return ®o.Function{ + Name: "keys.getAll", + Decl: types.NewFunction(nil, types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, terms []*ast.Term) (*ast.Term, error) { + uri, err := url.ParseRequestURI(sf.signerAddr + "/v1/keys") + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", uri.String(), nil) + if err != nil { + return nil, err + } + + resp, err := sf.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (sf *SignerFuncs) IssuerDID() (*rego.Function, rego.BuiltinDyn) { + return ®o.Function{ + Name: "issuer", + Decl: types.NewFunction(nil, types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, terms []*ast.Term) (*ast.Term, error) { + uri, err := url.ParseRequestURI(sf.signerAddr + "/v1/issuerDID") + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", uri.String(), nil) + if err != nil { + return nil, err + } + + resp, err := sf.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response from signer: %s", resp.Status) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (sf *SignerFuncs) CreateProof() (*rego.Function, rego.Builtin1) { + return ®o.Function{ + Name: "proof.create", + Decl: types.NewFunction(types.Args(types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, credential *ast.Term) (*ast.Term, error) { + // cred represents verifiable credential or presentation + var cred map[string]interface{} + if err := ast.As(credential.Value, &cred); err != nil { + return nil, fmt.Errorf("invalid credential: %s", err) + } + + if cred["type"] == nil { + return nil, fmt.Errorf("credential data does not specify type: must be VerifiablePresentation or VerifiableCredential") + } + + credType, ok := cred["type"].(string) + if !ok { + return nil, fmt.Errorf("invalid credential type, string is expected") + } + + var createProofPath string + switch credType { + case "VerifiableCredential": + createProofPath = "/v1/credential/proof" + case "VerifiablePresentation": + createProofPath = "/v1/presentation/proof" + default: + return nil, fmt.Errorf("unknown credential type: %q", credType) + } + + jsonCred, err := json.Marshal(cred) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", sf.signerAddr+createProofPath, bytes.NewReader(jsonCred)) + if err != nil { + return nil, err + } + + resp, err := sf.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response from signer: %d", resp.StatusCode) + } + + v, err := ast.ValueFromReader(resp.Body) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + } +} + +func (sf *SignerFuncs) VerifyProof() (*rego.Function, rego.Builtin1) { + return ®o.Function{ + Name: "proof.verify", + Decl: types.NewFunction(types.Args(types.S), types.A), + Memoize: true, + }, + func(bctx rego.BuiltinContext, credential *ast.Term) (*ast.Term, error) { + // cred represents verifiable credential or presentation + var cred map[string]interface{} + if err := ast.As(credential.Value, &cred); err != nil { + return nil, fmt.Errorf("invalid credential: %s", err) + } + + if cred["type"] == nil { + return nil, fmt.Errorf("credential data does not specify type: must be VerifiablePresentation or VerifiableCredential") + } + + credType, ok := cred["type"].(string) + if !ok { + return nil, fmt.Errorf("invalid credential type, string is expected") + } + + if cred["proof"] == nil { + return nil, fmt.Errorf("credential data does contain proof section") + } + + var verifyProofPath string + switch credType { + case "VerifiableCredential": + verifyProofPath = "/v1/credential/verify" + case "VerifiablePresentation": + verifyProofPath = "/v1/presentation/verify" + default: + return nil, fmt.Errorf("unknown credential type: %q", credType) + } + + jsonCred, err := json.Marshal(cred) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", sf.signerAddr+verifyProofPath, bytes.NewReader(jsonCred)) + if err != nil { + return nil, err + } + + resp, err := sf.httpClient.Do(req.WithContext(bctx.Context)) + if err != nil { + return nil, err + } + defer resp.Body.Close() // nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response from signer: %d", resp.StatusCode) + } + + var result struct { + Valid bool `json:"valid"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response from signer: %v", err) + } + + if !result.Valid { + return nil, fmt.Errorf("proof is invalid") + } + + return ast.NewTerm(ast.Boolean(true)), nil + } +} diff --git a/internal/regofunc/pubkeys_test.go b/internal/regofunc/signer_test.go similarity index 91% rename from internal/regofunc/pubkeys_test.go rename to internal/regofunc/signer_test.go index 11a7a7c5c084c99522ad4bb4febd865f980e0c49..77fad987b0b6e8b0bf78f910a8ff48d7374c9d9a 100644 --- a/internal/regofunc/pubkeys_test.go +++ b/internal/regofunc/signer_test.go @@ -21,7 +21,7 @@ func TestGetKeyFunc(t *testing.T) { })) defer signerSrv.Close() - keysFuncs := regofunc.NewPubkeyFuncs(signerSrv.URL, http.DefaultClient) + keysFuncs := regofunc.NewSignerFuncs(signerSrv.URL, http.DefaultClient) r := rego.New( rego.Query(`keys.get("key1")`), rego.Function1(keysFuncs.GetKeyFunc()), @@ -41,7 +41,7 @@ func TestGetKeyFuncError(t *testing.T) { })) defer signerSrv.Close() - keysFuncs := regofunc.NewPubkeyFuncs(signerSrv.URL, http.DefaultClient) + keysFuncs := regofunc.NewSignerFuncs(signerSrv.URL, http.DefaultClient) r := rego.New( rego.Query(`keys.get("key1")`), rego.Function1(keysFuncs.GetKeyFunc()), @@ -62,7 +62,7 @@ func TestGetAllKeysFunc(t *testing.T) { })) defer signerSrv.Close() - keysFuncs := regofunc.NewPubkeyFuncs(signerSrv.URL, http.DefaultClient) + keysFuncs := regofunc.NewSignerFuncs(signerSrv.URL, http.DefaultClient) r := rego.New( rego.Query(`keys.getAll()`), rego.FunctionDyn(keysFuncs.GetAllKeysFunc()), @@ -83,7 +83,7 @@ func TestIssuerDID(t *testing.T) { })) defer signerSrv.Close() - keysFuncs := regofunc.NewPubkeyFuncs(signerSrv.URL, http.DefaultClient) + keysFuncs := regofunc.NewSignerFuncs(signerSrv.URL, http.DefaultClient) r := rego.New( rego.Query(`issuer()`), rego.FunctionDyn(keysFuncs.IssuerDID()),