diff --git a/internal/clients/vault/client.go b/internal/clients/vault/client.go index d5e199372cbb10d8a5091175a826c946c16b2e6b..bf2b160f18c56426b90619508ab33903c6710f11 100644 --- a/internal/clients/vault/client.go +++ b/internal/clients/vault/client.go @@ -1,6 +1,7 @@ package vault import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -58,27 +59,23 @@ func (c *Client) WithKey(key string) signer.Signer { } // Key tries to fetch a key with the given name from the Vault. -func (c *Client) Key(key string) (*signer.SignKey, error) { +func (c *Client) Key(ctx context.Context, key string) (*signer.SignKey, error) { req := c.client.NewRequest(http.MethodGet, pathKeys+key) - res, err := c.client.RawRequest(req) + res, err := c.client.RawRequestWithContext(ctx, req) if err != nil { return nil, errors.New(errors.GetKind(res.StatusCode), err) } defer res.Body.Close() - var response struct { - Data struct { - Name string `json:"name"` - Type string `json:"type"` - } `json:"data"` - } + var response getKeyResponse if err := json.NewDecoder(res.Body).Decode(&response); err != nil { return nil, err } return &signer.SignKey{ - Name: response.Data.Name, - Type: response.Data.Type, + Name: response.Data.Name, + Type: response.Data.Type, + PublicKey: response.lastPublicKeyVersion(), }, nil } diff --git a/internal/clients/vault/client_test.go b/internal/clients/vault/client_test.go index 66258b01eaa6589094ecec26cf9b56358f5a8919..d0fc41e473dab450c53000374c766dd1810777bf 100644 --- a/internal/clients/vault/client_test.go +++ b/internal/clients/vault/client_test.go @@ -1,6 +1,7 @@ package vault_test import ( + "context" "net/http" "net/http/httptest" "testing" @@ -68,7 +69,7 @@ func TestClient_Key(t *testing.T) { client, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient) assert.NoError(t, err) - res, err := client.Key(test.key) + res, err := client.Key(context.Background(), test.key) if err != nil { assert.Nil(t, res) diff --git a/internal/clients/vault/transport.go b/internal/clients/vault/transport.go new file mode 100644 index 0000000000000000000000000000000000000000..612f13e433120afbb3b05c15ba28e09502d89988 --- /dev/null +++ b/internal/clients/vault/transport.go @@ -0,0 +1,33 @@ +package vault + +import "strconv" + +type getKeyResponse struct { + Data struct { + Name string `json:"name"` + Type string `json:"type"` + Keys map[string]struct { + PublicKey string `json:"public_key"` + } `json:"keys"` + } `json:"data"` +} + +// lastPublicKeyVersion iterates the map with key versions and +// returns the latest public key. +func (r *getKeyResponse) lastPublicKeyVersion() string { + var lastVerString string + var lastVerInt int + for ver := range r.Data.Keys { + verInt, err := strconv.Atoi(ver) + if err != nil { + continue + } + + if verInt > lastVerInt { + lastVerInt = verInt + lastVerString = ver + } + } + + return r.Data.Keys[lastVerString].PublicKey +} diff --git a/internal/clients/vault/transport_test.go b/internal/clients/vault/transport_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8b5a82e64ceaba9d2a89d74dc1ca7e28b4d4756b --- /dev/null +++ b/internal/clients/vault/transport_test.go @@ -0,0 +1,72 @@ +package vault + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_lastPublicKeyVersion(t *testing.T) { + tests := []struct { + name string + keys map[string]struct { + PublicKey string `json:"public_key"` + } + key string + }{ + { + name: "no keys in response", + }, + { + name: "one key in response", + keys: map[string]struct { + PublicKey string `json:"public_key"` + }{ + "1": {PublicKey: "key1"}, + }, + key: "key1", + }, + { + name: "two keys in response", + keys: map[string]struct { + PublicKey string `json:"public_key"` + }{ + "2": {PublicKey: "key2"}, + "1": {PublicKey: "key1"}, + }, + key: "key2", + }, + { + name: "three keys in response", + keys: map[string]struct { + PublicKey string `json:"public_key"` + }{ + "2": {PublicKey: "key2"}, + "1": {PublicKey: "key1"}, + "4": {PublicKey: "key4"}, + }, + key: "key4", + }, + { + name: "four keys in response", + keys: map[string]struct { + PublicKey string `json:"public_key"` + }{ + "2": {PublicKey: "key2"}, + "8": {PublicKey: "key8"}, + "1": {PublicKey: "key1"}, + "4": {PublicKey: "key4"}, + }, + key: "key8", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + response := &getKeyResponse{} + response.Data.Keys = test.keys + key := response.lastPublicKeyVersion() + assert.Equal(t, test.key, key) + }) + } +} diff --git a/internal/service/signer/service.go b/internal/service/signer/service.go index 06755754577a21cf4db4c77c181f216d651783c6..6e19f0b1b22b1efb3a4fc9f89b4cdf2ffdb3777c 100644 --- a/internal/service/signer/service.go +++ b/internal/service/signer/service.go @@ -21,12 +21,13 @@ import ( //go:generate counterfeiter . Signer type SignKey struct { - Name string - Type string + Name string `json:"name"` + Type string `json:"type"` + PublicKey string `json:"public_key,omitempty"` } type Signer interface { - Key(key string) (*SignKey, error) + Key(ctx context.Context, key string) (*SignKey, error) Sign(data []byte) ([]byte, error) WithKey(key string) Signer } @@ -48,6 +49,17 @@ func New(signer Signer, defaultKey string, httpClient *http.Client, logger *zap. } } +// GetKey returns a key from Vault or OCM. +func (s *Service) GetKey(ctx context.Context, req *signer.GetKeyRequest) (interface{}, error) { + key, err := s.signer.Key(ctx, req.Key) + if err != nil { + s.logger.Error("error getting key", zap.Error(err)) + return nil, err + } + + return key, nil +} + // CredentialProof adds a proof to a given Verifiable Credential. func (s *Service) CredentialProof(ctx context.Context, req *signer.CredentialProofRequest) (interface{}, error) { logger := s.logger.With(zap.String("operation", "credentialProof")) @@ -63,7 +75,7 @@ func (s *Service) CredentialProof(ctx context.Context, req *signer.CredentialPro keyname = *req.Key } - key, err := s.signer.Key(keyname) + key, err := s.signer.Key(ctx, keyname) if err != nil { logger.Error("error getting signing key", zap.String("key", keyname), zap.Error(err)) return nil, errors.New("error getting signing key", err) @@ -98,7 +110,7 @@ func (s *Service) PresentationProof(ctx context.Context, req *signer.Presentatio keyname = *req.Key } - key, err := s.signer.Key(keyname) + key, err := s.signer.Key(ctx, keyname) if err != nil { logger.Error("error getting signing key", zap.String("key", keyname), zap.Error(err)) return nil, errors.New("error getting signing key", err) diff --git a/internal/service/signer/service_test.go b/internal/service/signer/service_test.go index 9a44fb3b27cf0dc41b44aafb7340c34a9693d682..3c34e982648ec97a1086d226415b7d234e130931 100644 --- a/internal/service/signer/service_test.go +++ b/internal/service/signer/service_test.go @@ -17,6 +17,46 @@ import ( "code.vereign.com/gaiax/tsa/signer/internal/service/signer/signerfakes" ) +func TestService_GetKey(t *testing.T) { + t.Run("signer returns error when getting key", func(t *testing.T) { + signerError := &signerfakes.FakeSigner{ + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { + return nil, errors.New(errors.NotFound, "key not found") + }, + } + + svc := signer.New(signerError, "default key", http.DefaultClient, zap.NewNop()) + result, err := svc.GetKey(context.Background(), &goasigner.GetKeyRequest{Key: "key1"}) + assert.Nil(t, result) + assert.Error(t, err) + e, ok := err.(*errors.Error) + assert.True(t, ok) + assert.Equal(t, errors.NotFound, e.Kind) + }) + + t.Run("signer returns key successfully", func(t *testing.T) { + signerOK := &signerfakes.FakeSigner{ + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { + return &signer.SignKey{ + Name: "keyname", + Type: "ed25519", + PublicKey: "public key", + }, nil + }, + } + + svc := signer.New(signerOK, "default key", http.DefaultClient, zap.NewNop()) + result, err := svc.GetKey(context.Background(), &goasigner.GetKeyRequest{Key: "key1"}) + assert.NotNil(t, result) + assert.NoError(t, err) + assert.Equal(t, &signer.SignKey{ + Name: "keyname", + Type: "ed25519", + PublicKey: "public key", + }, result) + }) +} + func TestService_CredentialProof(t *testing.T) { tests := []struct { name string @@ -68,7 +108,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.NotFound) }, }, @@ -83,7 +123,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.Internal) }, }, @@ -98,7 +138,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.Internal) }, }, @@ -113,7 +153,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key23", Type: "rsa4096", @@ -131,7 +171,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key123", Type: "ed25519", @@ -164,7 +204,7 @@ func TestService_CredentialProof(t *testing.T) { Credential: []byte(validCredential), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key123", Type: "ecdsa-p256", @@ -273,7 +313,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.NotFound) }, }, @@ -288,7 +328,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.Internal) }, }, @@ -303,7 +343,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return nil, errors.New(errors.Internal) }, }, @@ -318,7 +358,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key23", Type: "rsa4096", @@ -336,7 +376,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key123", Type: "ed25519", @@ -367,7 +407,7 @@ func TestService_PresentationProof(t *testing.T) { Presentation: []byte(validPresentation), }, signer: &signerfakes.FakeSigner{ - KeyStub: func(key string) (*signer.SignKey, error) { + KeyStub: func(ctx context.Context, key string) (*signer.SignKey, error) { return &signer.SignKey{ Name: "key123", Type: "ecdsa-p256", diff --git a/internal/service/signer/signerfakes/fake_signer.go b/internal/service/signer/signerfakes/fake_signer.go index 5b2b696013317d774fe2aae6793758037054f786..9cd9bfbec59386f657a0a197d127c26462dca3b4 100644 --- a/internal/service/signer/signerfakes/fake_signer.go +++ b/internal/service/signer/signerfakes/fake_signer.go @@ -2,16 +2,18 @@ package signerfakes import ( + "context" "sync" "code.vereign.com/gaiax/tsa/signer/internal/service/signer" ) type FakeSigner struct { - KeyStub func(string) (*signer.SignKey, error) + KeyStub func(context.Context, string) (*signer.SignKey, error) keyMutex sync.RWMutex keyArgsForCall []struct { - arg1 string + arg1 context.Context + arg2 string } keyReturns struct { result1 *signer.SignKey @@ -49,18 +51,19 @@ type FakeSigner struct { invocationsMutex sync.RWMutex } -func (fake *FakeSigner) Key(arg1 string) (*signer.SignKey, error) { +func (fake *FakeSigner) Key(arg1 context.Context, arg2 string) (*signer.SignKey, error) { fake.keyMutex.Lock() ret, specificReturn := fake.keyReturnsOnCall[len(fake.keyArgsForCall)] fake.keyArgsForCall = append(fake.keyArgsForCall, struct { - arg1 string - }{arg1}) + arg1 context.Context + arg2 string + }{arg1, arg2}) stub := fake.KeyStub fakeReturns := fake.keyReturns - fake.recordInvocation("Key", []interface{}{arg1}) + fake.recordInvocation("Key", []interface{}{arg1, arg2}) fake.keyMutex.Unlock() if stub != nil { - return stub(arg1) + return stub(arg1, arg2) } if specificReturn { return ret.result1, ret.result2 @@ -74,17 +77,17 @@ func (fake *FakeSigner) KeyCallCount() int { return len(fake.keyArgsForCall) } -func (fake *FakeSigner) KeyCalls(stub func(string) (*signer.SignKey, error)) { +func (fake *FakeSigner) KeyCalls(stub func(context.Context, string) (*signer.SignKey, error)) { fake.keyMutex.Lock() defer fake.keyMutex.Unlock() fake.KeyStub = stub } -func (fake *FakeSigner) KeyArgsForCall(i int) string { +func (fake *FakeSigner) KeyArgsForCall(i int) (context.Context, string) { fake.keyMutex.RLock() defer fake.keyMutex.RUnlock() argsForCall := fake.keyArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeSigner) KeyReturns(result1 *signer.SignKey, result2 error) {