diff --git a/cmd/signer/main.go b/cmd/signer/main.go index a11b6d61916e33776261c703e8bc5b41e1002e6a..d93f8675d181af5c26a0a356c8214e675bdffc4c 100644 --- a/cmd/signer/main.go +++ b/cmd/signer/main.go @@ -50,9 +50,9 @@ func main() { httpClient := httpClient() - vault, err := vault.New(cfg.Vault.Addr, cfg.Vault.Token, httpClient) + vault, err := vault.New(cfg.Vault.Addr, cfg.Vault.Token, true, httpClient) if err != nil { - logger.Fatal("cannot create vault client", zap.Error(err)) + logger.Fatal("cannot initialize vault client", zap.Error(err)) } // create services diff --git a/internal/clients/vault/client.go b/internal/clients/vault/client.go index c5d8b46a69e03e9396c1e87c9a51a8c090142199..d5e199372cbb10d8a5091175a826c946c16b2e6b 100644 --- a/internal/clients/vault/client.go +++ b/internal/clients/vault/client.go @@ -24,7 +24,7 @@ type Client struct { } // New creates a Hashicorp Vault client. -func New(addr string, token string, httpClient *http.Client) (*Client, error) { +func New(addr string, token string, probe bool, httpClient *http.Client) (*Client, error) { cfg := vaultpkg.DefaultConfig() cfg.Address = addr cfg.HttpClient = httpClient @@ -35,6 +35,15 @@ func New(addr string, token string, httpClient *http.Client) (*Client, error) { client.SetToken(token) + // If probe is set, the client will try to query the vault to check if + // it's unsealed and ready for operation. This is used mostly so unit tests + // can bypass the check as they don't work against a real Vault. + if probe { + if _, err = client.Sys().Capabilities(token, pathSign); err != nil { + return nil, err + } + } + return &Client{cfg: cfg, client: client}, nil } diff --git a/internal/clients/vault/client_test.go b/internal/clients/vault/client_test.go index a68d75407c0202f2b4b5be40188dd0e8984126b2..66258b01eaa6589094ecec26cf9b56358f5a8919 100644 --- a/internal/clients/vault/client_test.go +++ b/internal/clients/vault/client_test.go @@ -65,7 +65,7 @@ func TestClient_Key(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { vaultsrv := httptest.NewServer(test.handler) - client, err := vault.New(vaultsrv.URL, "token", http.DefaultClient) + client, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient) assert.NoError(t, err) res, err := client.Key(test.key) @@ -90,7 +90,7 @@ func TestClient_WithKey(t *testing.T) { w.WriteHeader(http.StatusNotFound) })) - c1, err := vault.New(vaultsrv.URL, "token", http.DefaultClient) + c1, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient) assert.NoError(t, err) c2 := c1.WithKey("mytest-key123") @@ -147,7 +147,7 @@ func TestClient_Sign(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { vaultsrv := httptest.NewServer(test.handler) - client, err := vault.New(vaultsrv.URL, "token", http.DefaultClient) + client, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient) assert.NoError(t, err) res, err := client.Sign(test.data)