diff --git a/Gopkg.toml b/Gopkg.toml index 8e8e9943dd0a00009a9626d49bf8a20ae8e28e05..e876d43e089050bdf453650749cca9a1fff39ab9 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,5 +1,5 @@ [[constraint]] - branch = "master" + branch = "107-Vault_integration" name = "code.vereign.com/code/viam-apis" [[constraint]] diff --git a/config.yaml.sample b/config.yaml.sample index 6ce5342230fd19664bdc33042453401d9346061f..f89eed8f9c48a1f22636edc31195bff0aeb8698d 100644 --- a/config.yaml.sample +++ b/config.yaml.sample @@ -11,26 +11,43 @@ restListenAddress: localhost:7878 viamUUID: viam-system viamSession: viam-session -# Choose a certificate method for providing PEM strings -# 1 = Read from file (*.crt and *.key files) -# 2 = Read from Vault server (this will require additional config information for Vault) -certificateMethod: 1 - -# Read Certificates From Folder and Files -certDir: cert -certFile: server.crt -certKey: server.key -vereignCertFile: vereign_ca.cer -vereignCertKey: vereign_ca.key -caCertFile: ca.crt - # Maximum Message Size (in megabytes) maxMessageSize: 64 -# Read Certificates From Vault Server -vaultAddress: http://10.6.10.119:8200 -vaultToken: 00000000-0000-0000-0000-000000000000 -vaultPath: /developers/data/devteam/cert -certificateKey: certificateKey -privateKey: privateKey -caCertificateKey: caCertificateKey +# Certification Access Method +# 1 = Certificate Folder and Files +# 2 = Vault Integration +#certificationMethod: 1 +certificationMethod: 2 + +# Certification URL +# For Method 1: Can be anything (will be ignored) +# For Method 2: IP address and port number of Vault server +#certificationURL: localhost +certificationURL: http://10.6.10.119:8200 + +# Certification Token +# For Method 1: Can be anything (will be ignored) +# For Method 2: Vault authentication token +#certificationToken: . +certificationToken: YOUR_VAULT_TOKEN + +# Certification Path +# For Method 1: The full path of the folder where the certificate and private key files are stored +# For Method 2: Base mount path for certificate secrets +#certificationPath: /home/ocengiz/Documents/Vereign/03-Codebase/cert +certificationPath: /developers/data/devteam/cert + +# Certification Files +# For Method 1: The name of the files for certification +# For Method 2: The name of the secret keys on Vault +#certificationCertFile: localhost.crt +certificationCertFile: certificateKey +#certificationKeyFile: localhost.key +certificationKeyFile: privateKey +#certificationCaCertFile: ca.crt +certificationCaCertFile: caCertificateKey +#certificationVereignCertFile: vereign_ca.cer +certificationVereignCertFile: vereignCaCertificateKey +#certificationVereignKeyFile: vereign_ca.key +certificationVereignKeyFile: vereignCaPrivateKey \ No newline at end of file diff --git a/handler/generate_keypair.go b/handler/generate_keypair.go index 578008e483282bb428dc7baafe4781d8a715a52a..e9d4403806df659a34be139edf9d01a58611acb3 100644 --- a/handler/generate_keypair.go +++ b/handler/generate_keypair.go @@ -40,7 +40,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, auth := s.CreateAuthentication(ctx) client := &client.DataStorageClientImpl{} - client.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath, s.MaxMessageSize) + client.SetUpClient(auth, s.DataStorageUrl, s.CertPEM, s.KeyPEM, s.CaCertPEM, s.MaxMessageSize) defer client.CloseClient() generateKeyPairResponse := &api.GenerateKeyPairResponse{} @@ -94,7 +94,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, } if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 { - encryptedAesKeyBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, aesKeyBytes, []byte("aeskeys")) + encryptedAesKeyBytes, err := rsaEncryptWithServerKey(s.VereignCertPEM, aesKeyBytes, []byte("aeskeys")) if err != nil { log.Printf("Error: %v", err) generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, @@ -109,7 +109,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, } if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 { - encryptedPrivateKeyNonceBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, privateKeyNonce, []byte("nonce")) + encryptedPrivateKeyNonceBytes, err := rsaEncryptWithServerKey(s.VereignCertPEM, privateKeyNonce, []byte("nonce")) if err != nil { log.Printf("Error: %v", err) generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, @@ -176,8 +176,8 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) { return privateKeyBytes, publicKeyBytes, nil } -func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte) ([]byte, error) { - serverCertificate, err := readCertificateFromFile(certFilePath) +func rsaEncryptWithServerKey(certPEM []byte, message []byte, label []byte) ([]byte, error) { + serverCertificate, err := readCertificateFromPEM(certPEM) if err != nil { log.Printf("Error: %v", err) return nil, err diff --git a/handler/handler.go b/handler/handler.go index 535a449764488ba4a033c7c2e1968e399164a001..a5f98c712e64c0d3e9bc910de0f2573278eba604 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -19,7 +19,6 @@ package handler import ( "log" - "io/ioutil" "strings" "code.vereign.com/code/viam-apis/versions" @@ -36,11 +35,11 @@ import ( // Server represents the gRPC server type KeyStorageServerImpl struct { DataStorageUrl string - CertFilePath string - KeyFilePath string - CaCertFilePath string - VereignCertFilePath string - VereignPrivateKeyFilePath string + CertPEM []byte + KeyPEM []byte + CaCertPEM []byte + VereignCertPEM []byte + VereignPrivateKeyPEM []byte MaxMessageSize int } @@ -66,23 +65,23 @@ func (s *KeyStorageServerImpl) GetKey(ctx context.Context, in *api.GetKeyRequest auth := s.CreateAuthentication(ctx) client := &client.DataStorageClientImpl{} - client.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath, s.MaxMessageSize) + client.SetUpClient(auth, s.DataStorageUrl, s.CertPEM, s.KeyPEM, s.CaCertPEM, s.MaxMessageSize) defer client.CloseClient() getKeyResponse := &api.GetKeyResponse{} if in.KeyType == api.KeyType_CERTIFICATE && in.Uuid == "root" { key := &api.Key{} - data, err := ioutil.ReadFile(s.VereignCertFilePath) + /*data, err := ioutil.ReadFile(s.VereignCertFilePath) if err != nil { log.Printf("Error: %v", err) getKeyResponse.StatusList = utils.AddStatus(getKeyResponse.StatusList, "400", api.StatusType_ERROR, "Can not get root certificate") return getKeyResponse, nil - } + }*/ - key.Content = data + key.Content = s.VereignCertPEM key.Revoked = false getKeyResponse.Key = key @@ -114,7 +113,7 @@ func (s *KeyStorageServerImpl) SetKey(ctx context.Context, in *api.SetKeyRequest auth := s.CreateAuthentication(ctx) client := &client.DataStorageClientImpl{} - client.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath, s.MaxMessageSize) + client.SetUpClient(auth, s.DataStorageUrl, s.CertPEM, s.KeyPEM, s.CaCertPEM, s.MaxMessageSize) defer client.CloseClient() setKeyResponse := &api.SetKeyResponse{} @@ -162,7 +161,7 @@ func (s *KeyStorageServerImpl) ReserveKeyUUID(ctx context.Context, in *api.Reser auth := s.CreateAuthentication(ctx) client := &client.DataStorageClientImpl{} - client.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath, s.MaxMessageSize) + client.SetUpClient(auth, s.DataStorageUrl, s.CertPEM, s.KeyPEM, s.CaCertPEM, s.MaxMessageSize) defer client.CloseClient() reserveKeyUUIDResponse := &api.ReserveKeyUUIDResponse{} diff --git a/handler/revoke.go b/handler/revoke.go index 4522d3585bd5ede0e9326e78941afb7c5f39238e..a6ecda13aac32c8d67fe71b4eefd510cddff6cee 100644 --- a/handler/revoke.go +++ b/handler/revoke.go @@ -29,7 +29,7 @@ func (s *KeyStorageServerImpl) Revoke(ctx context.Context, in *api.RevokeRequest auth := s.CreateAuthentication(ctx) client := &client.DataStorageClientImpl{} - client.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath, s.MaxMessageSize) + client.SetUpClient(auth, s.DataStorageUrl, s.CertPEM, s.KeyPEM, s.CaCertPEM, s.MaxMessageSize) defer client.CloseClient() revokeResponse := &api.RevokeResponse{} diff --git a/handler/utils.go b/handler/utils.go index f1381396aff5754e53b71ee685ab652095693430..6624e77f6778542f7b37743f32ac1989edf21dff 100644 --- a/handler/utils.go +++ b/handler/utils.go @@ -75,6 +75,30 @@ func handlePutDataErrors(statusList []*api.Status, errors string, err error) []* return statusList } +func readCertificateFromPEM(pemString []byte) (*x509.Certificate, error) { + certificatePemBlock, err := readPemBlockFromBytes(pemString) + if err != nil { + log.Printf("Error: %v", err) + return nil, err + } + + certificate, err := x509.ParseCertificate(certificatePemBlock.Bytes) + if err != nil { + log.Printf("Error: %v", err) + return nil, err + } + + return certificate, nil +} + +func readPemBlockFromBytes(pemString []byte) (*pem.Block, error) { + fileBytes := pemString + + certificatePemBlock, _ := pem.Decode(fileBytes) + + return certificatePemBlock, nil +} + func readCertificateFromFile(fileName string) (*x509.Certificate, error) { certificatePemBlock, err := readPemBlockFromFile(fileName) if err != nil { diff --git a/main.go b/main.go index b49b48d44fc6fcdbe16ff6d77f136c7927e1beee..d8679721b966c434632c11dee46fbbd86f0562c1 100644 --- a/main.go +++ b/main.go @@ -28,29 +28,21 @@ import ( func main() { server.SetConfigValues() - // TODO this should be done via configuration or even a certificate repository - certDir := viper.GetString("certDir") - if certDir == "" { - log.Printf("cert-dir cannot be empty") - return - } - grpcAddress := viper.GetString("grpcListenAddress") restAddress := viper.GetString("restListenAddress") dataStorageAddress := viper.GetString("dataStorageUrl") - - certFilePath := certDir + "/" + viper.GetString("certFile") - privateKeyFilePath := certDir + "/" + viper.GetString("certKey") - caCertFilePath := certDir + "/" + viper.GetString("caCertFile") - vereignCertFilePath := certDir + "/" + viper.GetString("vereignCertFile") - vereignPrivateKeyFilePath := certDir + "/" + viper.GetString("vereignCertKey") + certPem := server.GetCertificatePEM() + keyPem := server.GetPrivateKeyPEM() + caCertPem := server.GetCaCertificatePEM() + vereignCaCertificatePem := server.GetVereignCaCertificatePEM() + vereignCaKeyPem := server.GetVereignCaKeyPEM() maxMessageSize := viper.GetInt("maxMessageSize") // fire the gRPC server in a goroutine go func() { - err := server.StartGRPCServer(grpcAddress, certFilePath, privateKeyFilePath, caCertFilePath, vereignCertFilePath, - vereignPrivateKeyFilePath, dataStorageAddress, maxMessageSize) + err := server.StartGRPCServer(grpcAddress, certPem, keyPem, caCertPem, vereignCaCertificatePem, + vereignCaKeyPem, dataStorageAddress, maxMessageSize) if err != nil { log.Fatalf("failed to start gRPC server: %s", err) } @@ -58,7 +50,7 @@ func main() { // fire the REST server in a goroutine go func() { - err := server.StartRESTServer(restAddress, grpcAddress, certFilePath) + err := server.StartRESTServer(restAddress, grpcAddress, certPem) if err != nil { log.Fatalf("failed to start gRPC server: %s", err) } diff --git a/server/configs.go b/server/configs.go index 28cfff0ebfb13104501448e3b2cd2921fe45856e..9cb3a86794d8595c0321caa8e1f4c4adb69a7f77 100644 --- a/server/configs.go +++ b/server/configs.go @@ -5,6 +5,14 @@ import ( "github.com/spf13/viper" ) +var certificationMethod string +var p PEMReader +var certificatePEM []byte +var privateKeyPEM []byte +var caCertificatePEM []byte +var vereignCaCertificatePEM []byte +var vereignCaKeyPEM []byte + func SetConfigValues() { // Set Default Values For Config Variables @@ -12,20 +20,37 @@ func SetConfigValues() { viper.SetDefault("grpcListenAddress", "localhost:7877") viper.SetDefault("restListenAddress", "localhost:7878") viper.SetDefault("dataStorageUrl", "localhost:7777") - - // Certificates Related - viper.SetDefault("certDir", "cert") - viper.SetDefault("certFile", "server.crt") - viper.SetDefault("certKey", "server.key") - viper.SetDefault("caCertFile", "ca.crt") - viper.SetDefault("vereignCertFile", "vereign_ca.cer") - viper.SetDefault("vereignCertKey", "vereign_ca.key") viper.SetDefault("viamUUID", "viam-system") viper.SetDefault("viamSession", "viam-session") viper.SetDefault("maxMessageSize", 64) + // Certification Related + // File System Defaults + viper.SetDefault("certificationMethod", "1") + viper.SetDefault("certificationURL", ".") + viper.SetDefault("certificationToken", ".") + viper.SetDefault("certificationPath", "cert") + viper.SetDefault("certificationCertFile", "server.crt") + viper.SetDefault("certificationKeyFile", "server.key") + viper.SetDefault("certificationCaCertFile", "ca.crt") + viper.SetDefault("certificationVereignCertFile", "vereign_ca.cer") + viper.SetDefault("certificationVereignKeyFile", "vereign_ca.key") + + /* + // Vault Defaults + viper.SetDefault("certificationMethod", "2") + viper.SetDefault("certificationURL", "http://10.6.10.119:8200") + viper.SetDefault("certificationToken", "") + viper.SetDefault("certificationPath", "/developers/data/devteam/cert") + viper.SetDefault("certificationCertFile", "certificateKey") + viper.SetDefault("certificationKeyFile", "privateKey") + viper.SetDefault("certificationCaCertFile", "caCertificateKey") + viper.SetDefault("certificationVereignCertFile", "vereignCaCertificateKey") + viper.SetDefault("certificationVereignKeyFile", "vereignCaPrivateKey") + */ + // Read Config File viper.SetConfigName("config") viper.AddConfigPath(".") @@ -33,10 +58,53 @@ func SetConfigValues() { log.Printf("can't read config: %s, will use default values", err) } + certificationMethod = viper.GetString("certificationMethod") + if certificationMethod == "1" { + // Read From File System + p = FilePEMReader{certificationURL: viper.GetString("certificationURL"), + certificationToken: viper.GetString("certificationToken"), + certificationPath: viper.GetString("certificationPath"), + certificationCertFile: viper.GetString("certificationCertFile"), + certificationKeyFile: viper.GetString("certificationKeyFile"), + certificationCaCertFile: viper.GetString("certificationCaCertFile"), + certificationVereignCertFile: viper.GetString("certificationVereignCertFile"), + certificationVereignKeyFile: viper.GetString("certificationVereignKeyFile")} + } else if certificationMethod == "2" { + // Read From Vault + p = VaultPEMReader{certificationURL: viper.GetString("certificationURL"), + certificationToken: viper.GetString("certificationToken"), + certificationPath: viper.GetString("certificationPath"), + certificationCertFile: viper.GetString("certificationCertFile"), + certificationKeyFile: viper.GetString("certificationKeyFile"), + certificationCaCertFile: viper.GetString("certificationCaCertFile"), + certificationVereignCertFile: viper.GetString("certificationVereignCertFile"), + certificationVereignKeyFile: viper.GetString("certificationVereignKeyFile")} + } + // Print all config values to log file log.Printf("All Settings From Config:") as := viper.AllSettings() for key, _ := range as { log.Printf("%s => %s", key, viper.GetString(key)) } +} + +func GetCertificatePEM() []byte { + return p.readCertificatePEM() +} + +func GetPrivateKeyPEM() []byte { + return p.readPrivateKeyPEM() +} + +func GetCaCertificatePEM() []byte { + return p.readCaCertificatePEM() +} + +func GetVereignCaCertificatePEM() []byte { + return p.readVereignCaCertificatePEM() +} + +func GetVereignCaKeyPEM() []byte { + return p.readVereignCaKeyPEM() } \ No newline at end of file diff --git a/server/pem_reader.go b/server/pem_reader.go new file mode 100644 index 0000000000000000000000000000000000000000..d6a11251631a9e42fcf43da350bd0687a3a71972 --- /dev/null +++ b/server/pem_reader.go @@ -0,0 +1,217 @@ +package server + +import ( + "log" + "io/ioutil" + vc "github.com/hashicorp/vault/api" +) + +type PEMReader interface { + readCertificatePEM() []byte + readPrivateKeyPEM() []byte + readCaCertificatePEM() []byte + readVereignCaCertificatePEM() []byte + readVereignCaKeyPEM() []byte +} + +type FilePEMReader struct { + certificationURL string + certificationToken string + certificationPath string + certificationCertFile string + certificationKeyFile string + certificationCaCertFile string + certificationVereignCertFile string + certificationVereignKeyFile string +} + +func (f FilePEMReader) readCertificatePEM() []byte { + pem, err := ioutil.ReadFile(f.certificationPath + "/" + f.certificationCertFile) + if err != nil { + log.Printf("Error: %v", err) + return []byte("") + } + return pem +} + +func (f FilePEMReader) readPrivateKeyPEM() []byte { + pem, err := ioutil.ReadFile(f.certificationPath + "/" + f.certificationKeyFile) + if err != nil { + log.Printf("Error: %v", err) + return []byte("") + } + return pem +} + +func (f FilePEMReader) readCaCertificatePEM() []byte { + pem, err := ioutil.ReadFile(f.certificationPath + "/" + f.certificationCaCertFile) + if err != nil { + log.Printf("Error: %v", err) + return []byte("") + } + return pem +} + +func (f FilePEMReader) readVereignCaCertificatePEM() []byte { + pem, err := ioutil.ReadFile(f.certificationPath + "/" + f.certificationVereignCertFile) + if err != nil { + log.Printf("Error: %v", err) + return []byte("") + } + return pem +} + +func (f FilePEMReader) readVereignCaKeyPEM() []byte { + pem, err := ioutil.ReadFile(f.certificationPath + "/" + f.certificationVereignKeyFile) + if err != nil { + log.Printf("Error: %v", err) + return []byte("") + } + return pem +} + +type VaultPEMReader struct { + certificationURL string + certificationToken string + certificationPath string + certificationCertFile string + certificationKeyFile string + certificationCaCertFile string + certificationVereignCertFile string + certificationVereignKeyFile string +} + +func (v VaultPEMReader) readCertificatePEM() []byte { + client, err := vc.NewClient(&vc.Config{ + Address: v.certificationURL, + }) + + if err != nil { + log.Printf("Error: VAULT Can't create client, %s", err) + } + + client.SetToken(v.certificationToken) + + keyname := v.certificationPath + secretValues, err := client.Logical().Read(keyname) + if err != nil { + log.Printf("Error: VAULT Can't read value, %s", err) + } + + pemMap := secretValues.Data["data"].(map[string]interface{}) + + for propName, propValue := range pemMap { + if propName == v.certificationCertFile { + return []byte(propValue.(string)) + } + } + return []byte("") +} + +func (v VaultPEMReader) readPrivateKeyPEM() []byte { + client, err := vc.NewClient(&vc.Config{ + Address: v.certificationURL, + }) + + if err != nil { + log.Printf("Error: VAULT Can't create client, %s", err) + } + + client.SetToken(v.certificationToken) + + keyname := v.certificationPath + secretValues, err := client.Logical().Read(keyname) + if err != nil { + log.Printf("Error: VAULT Can't read value, %s", err) + } + + pemMap := secretValues.Data["data"].(map[string]interface{}) + + for propName, propValue := range pemMap { + if propName == v.certificationKeyFile { + return []byte(propValue.(string)) + } + } + return []byte("") +} + +func (v VaultPEMReader) readCaCertificatePEM() []byte { + client, err := vc.NewClient(&vc.Config{ + Address: v.certificationURL, + }) + + if err != nil { + log.Printf("Error: VAULT Can't create client, %s", err) + } + + client.SetToken(v.certificationToken) + + keyname := v.certificationPath + secretValues, err := client.Logical().Read(keyname) + if err != nil { + log.Printf("Error: VAULT Can't read value, %s", err) + } + + pemMap := secretValues.Data["data"].(map[string]interface{}) + + for propName, propValue := range pemMap { + if propName == v.certificationCaCertFile { + return []byte(propValue.(string)) + } + } + return []byte("") +} + +func (v VaultPEMReader) readVereignCaCertificatePEM() []byte { + client, err := vc.NewClient(&vc.Config{ + Address: v.certificationURL, + }) + + if err != nil { + log.Printf("Error: VAULT Can't create client, %s", err) + } + + client.SetToken(v.certificationToken) + + keyname := v.certificationPath + secretValues, err := client.Logical().Read(keyname) + if err != nil { + log.Printf("Error: VAULT Can't read value, %s", err) + } + + pemMap := secretValues.Data["data"].(map[string]interface{}) + + for propName, propValue := range pemMap { + if propName == v.certificationVereignCertFile { + return []byte(propValue.(string)) + } + } + return []byte("") +} + +func (v VaultPEMReader) readVereignCaKeyPEM() []byte { + client, err := vc.NewClient(&vc.Config{ + Address: v.certificationURL, + }) + + if err != nil { + log.Printf("Error: VAULT Can't create client, %s", err) + } + + client.SetToken(v.certificationToken) + + keyname := v.certificationPath + secretValues, err := client.Logical().Read(keyname) + if err != nil { + log.Printf("Error: VAULT Can't read value, %s", err) + } + + pemMap := secretValues.Data["data"].(map[string]interface{}) + + for propName, propValue := range pemMap { + if propName == v.certificationVereignKeyFile { + return []byte(propValue.(string)) + } + } + return []byte("") +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index 5dcbe0eb7da5169c56898390935e35fcd98a7325..654be0e361c695c445361b345866bf79f1bdf847 100644 --- a/server/server.go +++ b/server/server.go @@ -23,11 +23,10 @@ import ( "net" "net/http" "strings" - + "crypto/x509" + "crypto/tls" "github.com/grpc-ecosystem/grpc-gateway/runtime" - "golang.org/x/net/context" - "code.vereign.com/code/key-storage-agent/handler" "code.vereign.com/code/key-storage-agent/session" "code.vereign.com/code/viam-apis/authentication" @@ -46,9 +45,9 @@ const ( clientIDKey contextKey = iota ) -var pkgCertFile string -var pkgKeyFile string -var pkgCaCertFile string +var pkgCertPEM []byte +var pkgKeyPEM []byte +var pkgCaCertPEM []byte func credMatcher(headerName string) (mdName string, ok bool) { if headerName == "Session" { @@ -72,7 +71,7 @@ func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, in } sessionClient := &client.DataStorageClientImpl{} - sessionClient.SetUpClient(viamAuth, viper.GetString("dataStorageUrl"), pkgCertFile, pkgKeyFile, pkgCaCertFile, viper.GetInt("maxMessageSize")) + sessionClient.SetUpClient(viamAuth, viper.GetString("dataStorageUrl"), pkgCertPEM, pkgKeyPEM, pkgCaCertPEM, viper.GetInt("maxMessageSize")) defer sessionClient.CloseClient() if clientAuth.Uuid == viamAuth.Uuid { @@ -110,10 +109,10 @@ func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServ return handler1(ctx, req) } -func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath, vereignCertFilePath, vereignPrivateKeyFilePath, dataStorageAddress string, maxMessageSize int) error { - pkgCertFile = certFilePath - pkgKeyFile = privateKeyFilePath - pkgCaCertFile = caCertFilePath +func StartGRPCServer(address string, certPEM, privateKeyPEM, caCertPEM, vereignCertPEM, vereignPrivateKeyPEM []byte, dataStorageAddress string, maxMessageSize int) error { + pkgCertPEM = certPEM + pkgKeyPEM = privateKeyPEM + pkgCaCertPEM = caCertPEM // create a listener on TCP port lis, err := net.Listen("tcp", address) @@ -124,21 +123,40 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath, // create a server instance s := handler.KeyStorageServerImpl{ - DataStorageUrl: dataStorageAddress, - CertFilePath: certFilePath, - KeyFilePath: privateKeyFilePath, - CaCertFilePath: caCertFilePath, - VereignCertFilePath: vereignCertFilePath, - VereignPrivateKeyFilePath: vereignPrivateKeyFilePath, - MaxMessageSize: maxMessageSize, + DataStorageUrl: dataStorageAddress, + CertPEM: certPEM, + KeyPEM: privateKeyPEM, + CaCertPEM: caCertPEM, + VereignCertPEM: vereignCertPEM, + VereignPrivateKeyPEM: vereignPrivateKeyPEM, + MaxMessageSize: maxMessageSize, } - // Create the TLS credentials - creds, err := credentials.NewServerTLSFromFile(certFilePath, privateKeyFilePath) + // Load the certificates from PEM Strings + certificate, err := tls.X509KeyPair(certPEM, privateKeyPEM) + if err != nil { log.Printf("Error: %v", err) - return fmt.Errorf("could not load TLS keys: %s", err) + return fmt.Errorf("could not load server key pair: %s", err) + } + + // Create a certificate pool from the certificate authority + // Get the SystemCertPool, continue with an empty pool on error + certPool, _ := x509.SystemCertPool() + if certPool == nil { + certPool = x509.NewCertPool() } + + if ok := certPool.AppendCertsFromPEM(caCertPEM); !ok { + return fmt.Errorf("failed to append server certs") + } + + // Create the TLS credentials + creds := credentials.NewTLS(&tls.Config{ + //ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{certificate}, + ClientCAs: certPool, + }) // Create an array of gRPC options with the credentials opts := []grpc.ServerOption{ @@ -162,19 +180,25 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath, return nil } -func StartRESTServer(address, grpcAddress, certFile string) error { +func StartRESTServer(address, grpcAddress string, certPEM []byte) error { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() mux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(credMatcher)) - creds, err := credentials.NewClientTLSFromFile(certFile, "") - if err != nil { - log.Printf("Error: %v", err) - return fmt.Errorf("could not load TLS certificate: %s", err) + certPool, err := x509.SystemCertPool() + if certPool == nil { + certPool = x509.NewCertPool() } + // Append the client certificates from the CA + if ok := certPool.AppendCertsFromPEM(certPEM); !ok { + return fmt.Errorf("failed to append client certs") + } + + creds := credentials.NewClientTLSFromCert(certPool, "") + // Setup the client gRPC options opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}