Skip to content
Snippets Groups Projects
Commit 03812203 authored by Gospodin Bodurov's avatar Gospodin Bodurov
Browse files

Merge branch 'devops22-code-cleanup' into 'master'

Devops22 code cleanup

See merge request !25
parents f53049a1 f8ce57c8
No related branches found
No related tags found
2 merge requests!25Devops22 code cleanup,!15WIP: Master
...@@ -7,6 +7,10 @@ dataStorageUrl: localhost:7777 ...@@ -7,6 +7,10 @@ dataStorageUrl: localhost:7777
grpcListenAddress: localhost:7877 grpcListenAddress: localhost:7877
restListenAddress: localhost:7878 restListenAddress: localhost:7878
# VIAM Variables
viamUUID: viam-system
viamSession: viam-session
# Choose a certificate method for providing PEM strings # Choose a certificate method for providing PEM strings
# 1 = Read from file (*.crt and *.key files) # 1 = Read from file (*.crt and *.key files)
# 2 = Read from Vault server (this will require additional config information for Vault) # 2 = Read from Vault server (this will require additional config information for Vault)
......
...@@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. ...@@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package handler package handler
import ( import (
"log"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
...@@ -46,24 +47,28 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, ...@@ -46,24 +47,28 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
uuid, err := generateUnusedUUID(client) uuid, err := generateUnusedUUID(client)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
} }
privateKeyBytes, publicKeyBytes, err := generateKeyPair(int(in.KeySize)) privateKeyBytes, publicKeyBytes, err := generateKeyPair(int(in.KeySize))
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
} }
aesKeyBytes, err := generateRandomSequence(256) aesKeyBytes, err := generateRandomSequence(256)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
} }
encryptedPrivateKeyBytes, privateKeyNonce, err := aesEncrypt(aesKeyBytes, privateKeyBytes) encryptedPrivateKeyBytes, privateKeyNonce, err := aesEncrypt(aesKeyBytes, privateKeyBytes)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
return generateKeyPairResponse, nil return generateKeyPairResponse, nil
...@@ -91,6 +96,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, ...@@ -91,6 +96,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 { if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 {
encryptedAesKeyBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, aesKeyBytes, []byte("aeskeys")) encryptedAesKeyBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, aesKeyBytes, []byte("aeskeys"))
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
return generateKeyPairResponse, nil return generateKeyPairResponse, nil
...@@ -105,6 +111,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, ...@@ -105,6 +111,7 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 { if generateKeyPairResponse.StatusList == nil || len(generateKeyPairResponse.StatusList) == 0 {
encryptedPrivateKeyNonceBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, privateKeyNonce, []byte("nonce")) encryptedPrivateKeyNonceBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, privateKeyNonce, []byte("nonce"))
if err != nil { if err != nil {
log.Printf("Error: %v", err)
generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList, generateKeyPairResponse.StatusList = utils.AddStatus(generateKeyPairResponse.StatusList,
"500", api.StatusType_ERROR, err.Error()) "500", api.StatusType_ERROR, err.Error())
return generateKeyPairResponse, nil return generateKeyPairResponse, nil
...@@ -128,11 +135,13 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context, ...@@ -128,11 +135,13 @@ func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
func generateKeyPair(keySize int) ([]byte, []byte, error) { func generateKeyPair(keySize int) ([]byte, []byte, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, keySize) privateKey, err := rsa.GenerateKey(rand.Reader, keySize)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
err = privateKey.Validate() err = privateKey.Validate()
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
...@@ -140,6 +149,7 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) { ...@@ -140,6 +149,7 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) {
pkcs8PrivateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) pkcs8PrivateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
...@@ -152,6 +162,7 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) { ...@@ -152,6 +162,7 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) {
pkixPublicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) pkixPublicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
...@@ -168,12 +179,14 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) { ...@@ -168,12 +179,14 @@ func generateKeyPair(keySize int) ([]byte, []byte, error) {
func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte) ([]byte, error) { func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte) ([]byte, error) {
serverCertificate, err := readCertificateFromFile(certFilePath) serverCertificate, err := readCertificateFromFile(certFilePath)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
serverPublicKey := serverCertificate.PublicKey.(*rsa.PublicKey) serverPublicKey := serverCertificate.PublicKey.(*rsa.PublicKey)
encryptedMessageBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, serverPublicKey, message, label) encryptedMessageBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, serverPublicKey, message, label)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
...@@ -183,16 +196,19 @@ func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte) ...@@ -183,16 +196,19 @@ func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte)
func aesEncrypt(aesKey []byte, message []byte) ([]byte, []byte, error) { func aesEncrypt(aesKey []byte, message []byte) ([]byte, []byte, error) {
block, err := aes.NewCipher(aesKey) block, err := aes.NewCipher(aesKey)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
aesgcm, err := cipher.NewGCM(block) aesgcm, err := cipher.NewGCM(block)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
nonce, err := generateRandomSequence(aesgcm.NonceSize() * 8) nonce, err := generateRandomSequence(aesgcm.NonceSize() * 8)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, nil, err return nil, nil, err
} }
...@@ -206,6 +222,7 @@ func generateRandomSequence(keySize int) ([]byte, error) { ...@@ -206,6 +222,7 @@ func generateRandomSequence(keySize int) ([]byte, error) {
_, err := rand.Read(key) _, err := rand.Read(key)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
......
...@@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. ...@@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package handler package handler
import ( import (
"fmt" "log"
"io/ioutil" "io/ioutil"
"strings" "strings"
...@@ -76,6 +76,7 @@ func (s *KeyStorageServerImpl) GetKey(ctx context.Context, in *api.GetKeyRequest ...@@ -76,6 +76,7 @@ func (s *KeyStorageServerImpl) GetKey(ctx context.Context, in *api.GetKeyRequest
data, err := ioutil.ReadFile(s.VereignCertFilePath) data, err := ioutil.ReadFile(s.VereignCertFilePath)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
getKeyResponse.StatusList = utils.AddStatus(getKeyResponse.StatusList, getKeyResponse.StatusList = utils.AddStatus(getKeyResponse.StatusList,
"400", api.StatusType_ERROR, "Can not get root certificate") "400", api.StatusType_ERROR, "Can not get root certificate")
return getKeyResponse, nil return getKeyResponse, nil
...@@ -168,6 +169,7 @@ func (s *KeyStorageServerImpl) ReserveKeyUUID(ctx context.Context, in *api.Reser ...@@ -168,6 +169,7 @@ func (s *KeyStorageServerImpl) ReserveKeyUUID(ctx context.Context, in *api.Reser
uuid, err := generateUnusedUUID(client) uuid, err := generateUnusedUUID(client)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
reserveKeyUUIDResponse.StatusList = utils.AddStatus(reserveKeyUUIDResponse.StatusList, reserveKeyUUIDResponse.StatusList = utils.AddStatus(reserveKeyUUIDResponse.StatusList,
"500", api.StatusType_INFO, err.Error()) "500", api.StatusType_INFO, err.Error())
} }
...@@ -199,6 +201,6 @@ func (s *KeyStorageServerImpl) ReserveKeyUUID(ctx context.Context, in *api.Reser ...@@ -199,6 +201,6 @@ func (s *KeyStorageServerImpl) ReserveKeyUUID(ctx context.Context, in *api.Reser
} }
func (s *KeyStorageServerImpl) GetVersionKSA(ctx context.Context, in *api.GetVersionKSAMessage) (*api.GetVersionKSAResponseMessage, error) { func (s *KeyStorageServerImpl) GetVersionKSA(ctx context.Context, in *api.GetVersionKSAMessage) (*api.GetVersionKSAResponseMessage, error) {
fmt.Println("Version: " + version) log.Println("Version: " + version)
return &api.GetVersionKSAResponseMessage{Version: version, Errors: ""}, nil return &api.GetVersionKSAResponseMessage{Version: version, Errors: ""}, nil
} }
\ No newline at end of file
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"log"
"io" "io"
"code.vereign.com/code/viam-apis/data-storage-agent/client" "code.vereign.com/code/viam-apis/data-storage-agent/client"
...@@ -65,6 +66,7 @@ func newUUID() (string, error) { ...@@ -65,6 +66,7 @@ func newUUID() (string, error) {
func handlePutDataErrors(statusList []*api.Status, errors string, err error) []*api.Status { func handlePutDataErrors(statusList []*api.Status, errors string, err error) []*api.Status {
if err != nil { if err != nil {
log.Printf("Error: %v", err)
statusList = utils.AddStatus(statusList, "500", api.StatusType_ERROR, err.Error()) statusList = utils.AddStatus(statusList, "500", api.StatusType_ERROR, err.Error())
} else if errors != "" { } else if errors != "" {
statusList = utils.AddStatus(statusList, "400", api.StatusType_ERROR, errors) statusList = utils.AddStatus(statusList, "400", api.StatusType_ERROR, errors)
...@@ -76,11 +78,13 @@ func handlePutDataErrors(statusList []*api.Status, errors string, err error) []* ...@@ -76,11 +78,13 @@ func handlePutDataErrors(statusList []*api.Status, errors string, err error) []*
func readCertificateFromFile(fileName string) (*x509.Certificate, error) { func readCertificateFromFile(fileName string) (*x509.Certificate, error) {
certificatePemBlock, err := readPemBlockFromFile(fileName) certificatePemBlock, err := readPemBlockFromFile(fileName)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
certificate, err := x509.ParseCertificate(certificatePemBlock.Bytes) certificate, err := x509.ParseCertificate(certificatePemBlock.Bytes)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
...@@ -90,6 +94,7 @@ func readCertificateFromFile(fileName string) (*x509.Certificate, error) { ...@@ -90,6 +94,7 @@ func readCertificateFromFile(fileName string) (*x509.Certificate, error) {
func readPemBlockFromFile(fileName string) (*pem.Block, error) { func readPemBlockFromFile(fileName string) (*pem.Block, error) {
fileBytes, err := ioutil.ReadFile(fileName) fileBytes, err := ioutil.ReadFile(fileName)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
......
...@@ -21,6 +21,9 @@ func SetConfigValues() { ...@@ -21,6 +21,9 @@ func SetConfigValues() {
viper.SetDefault("vereignCertFile", "vereign_ca.cer") viper.SetDefault("vereignCertFile", "vereign_ca.cer")
viper.SetDefault("vereignCertKey", "vereign_ca.key") viper.SetDefault("vereignCertKey", "vereign_ca.key")
viper.SetDefault("viamUUID", "viam-system")
viper.SetDefault("viamSession", "viam-session")
viper.SetDefault("maxMessageSize", "32") viper.SetDefault("maxMessageSize", "32")
// Read Config File // Read Config File
......
...@@ -67,8 +67,8 @@ func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, in ...@@ -67,8 +67,8 @@ func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, in
} }
viamAuth := &authentication.Authentication{ viamAuth := &authentication.Authentication{
Uuid: "viam-system", Uuid: viper.GetString("viamUUID"),
Session: "viam-session", Session: viper.GetString("viamSession"),
} }
sessionClient := &client.DataStorageClientImpl{} sessionClient := &client.DataStorageClientImpl{}
...@@ -95,12 +95,13 @@ func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, in ...@@ -95,12 +95,13 @@ func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, in
// unaryInterceptor call authenticateClient with current context // unaryInterceptor call authenticateClient with current context
func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler1 grpc.UnaryHandler) (interface{}, error) { func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler1 grpc.UnaryHandler) (interface{}, error) {
s, ok := info.Server.(*handler.KeyStorageServerImpl) s, ok := info.Server.(*handler.KeyStorageServerImpl)
fmt.Println("Invoked method: " + info.FullMethod) log.Println("Invoked method: " + info.FullMethod)
if !ok { if !ok {
return nil, fmt.Errorf("unable to cast server") return nil, fmt.Errorf("unable to cast server")
} }
clientID, err := authenticateClient(ctx, s, info.FullMethod) clientID, err := authenticateClient(ctx, s, info.FullMethod)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return nil, err return nil, err
} }
...@@ -117,6 +118,7 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath, ...@@ -117,6 +118,7 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath,
// create a listener on TCP port // create a listener on TCP port
lis, err := net.Listen("tcp", address) lis, err := net.Listen("tcp", address)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return fmt.Errorf("failed to listen: %v", err) return fmt.Errorf("failed to listen: %v", err)
} }
...@@ -134,6 +136,7 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath, ...@@ -134,6 +136,7 @@ func StartGRPCServer(address, certFilePath, privateKeyFilePath, caCertFilePath,
// Create the TLS credentials // Create the TLS credentials
creds, err := credentials.NewServerTLSFromFile(certFilePath, privateKeyFilePath) creds, err := credentials.NewServerTLSFromFile(certFilePath, privateKeyFilePath)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return fmt.Errorf("could not load TLS keys: %s", err) return fmt.Errorf("could not load TLS keys: %s", err)
} }
...@@ -168,6 +171,7 @@ func StartRESTServer(address, grpcAddress, certFile string) error { ...@@ -168,6 +171,7 @@ func StartRESTServer(address, grpcAddress, certFile string) error {
creds, err := credentials.NewClientTLSFromFile(certFile, "") creds, err := credentials.NewClientTLSFromFile(certFile, "")
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return fmt.Errorf("could not load TLS certificate: %s", err) return fmt.Errorf("could not load TLS certificate: %s", err)
} }
...@@ -177,6 +181,7 @@ func StartRESTServer(address, grpcAddress, certFile string) error { ...@@ -177,6 +181,7 @@ func StartRESTServer(address, grpcAddress, certFile string) error {
// Register RedisStorageServer // Register RedisStorageServer
err = api.RegisterKeyStorageHandlerFromEndpoint(ctx, mux, grpcAddress, opts) err = api.RegisterKeyStorageHandlerFromEndpoint(ctx, mux, grpcAddress, opts)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return fmt.Errorf("could not register service RedisStorageServer: %s", err) return fmt.Errorf("could not register service RedisStorageServer: %s", err)
} }
......
...@@ -18,12 +18,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>. ...@@ -18,12 +18,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
package session package session
import ( import (
"log"
client "code.vereign.com/code/viam-apis/data-storage-agent/client" client "code.vereign.com/code/viam-apis/data-storage-agent/client"
) )
func CheckSession(uuid string, session string, sessionClient *client.DataStorageClientImpl) bool { func CheckSession(uuid string, session string, sessionClient *client.DataStorageClientImpl) bool {
hasSession, _, err := sessionClient.HasSession(uuid, session) hasSession, _, err := sessionClient.HasSession(uuid, session)
if err != nil { if err != nil {
log.Printf("Error: %v", err)
return false return false
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment