From 47f17a0be8fc8c7a483958ea23fb867c9699d17b Mon Sep 17 00:00:00 2001
From: Lyuben Penkovski <penkovski@gmail.com>
Date: Wed, 15 Jun 2022 09:55:33 +0300
Subject: [PATCH] Add vault client initialization check

Make a request when creating the client to see if the Vault
is unsealed and available for operation.
---
 cmd/signer/main.go                    |  4 ++--
 internal/clients/vault/client.go      | 11 ++++++++++-
 internal/clients/vault/client_test.go |  6 +++---
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/cmd/signer/main.go b/cmd/signer/main.go
index a11b6d6..d93f867 100644
--- a/cmd/signer/main.go
+++ b/cmd/signer/main.go
@@ -50,9 +50,9 @@ func main() {
 
 	httpClient := httpClient()
 
-	vault, err := vault.New(cfg.Vault.Addr, cfg.Vault.Token, httpClient)
+	vault, err := vault.New(cfg.Vault.Addr, cfg.Vault.Token, true, httpClient)
 	if err != nil {
-		logger.Fatal("cannot create vault client", zap.Error(err))
+		logger.Fatal("cannot initialize vault client", zap.Error(err))
 	}
 
 	// create services
diff --git a/internal/clients/vault/client.go b/internal/clients/vault/client.go
index c5d8b46..d5e1993 100644
--- a/internal/clients/vault/client.go
+++ b/internal/clients/vault/client.go
@@ -24,7 +24,7 @@ type Client struct {
 }
 
 // New creates a Hashicorp Vault client.
-func New(addr string, token string, httpClient *http.Client) (*Client, error) {
+func New(addr string, token string, probe bool, httpClient *http.Client) (*Client, error) {
 	cfg := vaultpkg.DefaultConfig()
 	cfg.Address = addr
 	cfg.HttpClient = httpClient
@@ -35,6 +35,15 @@ func New(addr string, token string, httpClient *http.Client) (*Client, error) {
 
 	client.SetToken(token)
 
+	// If probe is set, the client will try to query the vault to check if
+	// it's unsealed and ready for operation. This is used mostly so unit tests
+	// can bypass the check as they don't work against a real Vault.
+	if probe {
+		if _, err = client.Sys().Capabilities(token, pathSign); err != nil {
+			return nil, err
+		}
+	}
+
 	return &Client{cfg: cfg, client: client}, nil
 }
 
diff --git a/internal/clients/vault/client_test.go b/internal/clients/vault/client_test.go
index a68d754..66258b0 100644
--- a/internal/clients/vault/client_test.go
+++ b/internal/clients/vault/client_test.go
@@ -65,7 +65,7 @@ func TestClient_Key(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			vaultsrv := httptest.NewServer(test.handler)
-			client, err := vault.New(vaultsrv.URL, "token", http.DefaultClient)
+			client, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient)
 			assert.NoError(t, err)
 
 			res, err := client.Key(test.key)
@@ -90,7 +90,7 @@ func TestClient_WithKey(t *testing.T) {
 		w.WriteHeader(http.StatusNotFound)
 	}))
 
-	c1, err := vault.New(vaultsrv.URL, "token", http.DefaultClient)
+	c1, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient)
 	assert.NoError(t, err)
 
 	c2 := c1.WithKey("mytest-key123")
@@ -147,7 +147,7 @@ func TestClient_Sign(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			vaultsrv := httptest.NewServer(test.handler)
-			client, err := vault.New(vaultsrv.URL, "token", http.DefaultClient)
+			client, err := vault.New(vaultsrv.URL, "token", false, http.DefaultClient)
 			assert.NoError(t, err)
 
 			res, err := client.Sign(test.data)
-- 
GitLab