Skip to content
Snippets Groups Projects
generate_keypair.go 7.15 KiB
Newer Older
  • Learn to ignore specific revisions
  • /*
    Copyright (c) 2018 Vereign AG [https://www.vereign.com]
    
    This is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.
    
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    
    You should have received a copy of the GNU Affero General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
    */
    
    package handler
    
    import (
    	"crypto/aes"
    	"crypto/cipher"
    	"crypto/rand"
    	"crypto/rsa"
    	"crypto/sha256"
    	"crypto/x509"
    
    	"encoding/pem"
    
    
    	"code.vereign.com/code/viam-apis/data-storage-agent/client"
    	"code.vereign.com/code/viam-apis/key-storage-agent/api"
    	"code.vereign.com/code/viam-apis/utils"
    	"code.vereign.com/code/viam-apis/versions"
    	"golang.org/x/net/context"
    )
    
    
    // GenerateKeyPair creates an RSA key pair of in.KeySize bits for a freshly
    // reserved UUID and stores it in the data storage agent's "keys" collection
    // under "<uuid>/<KeyType>":
    //
    //   - PRIVATE:     PEM private key, AES-GCM encrypted with a new 256-bit AES key
    //   - PUBLIC:      PEM public key, unencrypted
    //   - CERTIFICATE: empty placeholder (duplicates the logic of ReserveKeyUUID)
    //   - AES:         the AES key, RSA-OAEP encrypted with the Vereign server certificate
    //   - NONCE:       the GCM nonce, RSA-OAEP encrypted the same way
    //
    // Failures are reported through the response's StatusList; the returned error
    // is always nil. Processing stops at the first failure. All cryptographic work
    // happens before the first storage call, so a crypto or certificate problem
    // writes nothing. On success the response carries the new UUID and a single
    // "200" INFO status holding the result of the last storage call.
    //
    // NOTE(review): entries already written before a failing storage call are not
    // rolled back.
    func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
    	in *api.GenerateKeyPairRequest) (*api.GenerateKeyPairResponse, error) {
    	auth := s.CreateAuthentication(ctx)

    	// Named dataClient so it does not shadow the imported client package.
    	dataClient := &client.DataStorageClientImpl{}
    	dataClient.SetUpClient(auth, s.DataStorageUrl, s.CertFilePath, s.KeyFilePath, s.CaCertFilePath)
    	defer dataClient.CloseClient()

    	response := &api.GenerateKeyPairResponse{}

    	// fail records err as a "500" error status and returns the response. Every
    	// later step depends on the earlier ones, so there is nothing left to do.
    	fail := func(err error) (*api.GenerateKeyPairResponse, error) {
    		response.StatusList = utils.AddStatus(response.StatusList,
    			"500", api.StatusType_ERROR, err.Error())
    		return response, nil
    	}

    	uuid, err := generateUnusedUUID(dataClient)
    	if err != nil {
    		return fail(err)
    	}

    	privateKeyBytes, publicKeyBytes, err := generateKeyPair(int(in.KeySize))
    	if err != nil {
    		return fail(err)
    	}

    	// 256 bits -> 32-byte key, i.e. AES-256.
    	aesKeyBytes, err := generateRandomSequence(256)
    	if err != nil {
    		return fail(err)
    	}

    	encryptedPrivateKeyBytes, privateKeyNonce, err := aesEncrypt(aesKeyBytes, privateKeyBytes)
    	if err != nil {
    		return fail(err)
    	}

    	// Wrap the AES key and the nonce before writing anything, so that an
    	// unreadable or unsuitable server certificate leaves no partial entries.
    	encryptedAesKeyBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, aesKeyBytes, []byte("aeskeys"))
    	if err != nil {
    		return fail(err)
    	}

    	encryptedNonceBytes, err := rsaEncryptWithServerKey(s.VereignCertFilePath, privateKeyNonce, []byte("nonce"))
    	if err != nil {
    		return fail(err)
    	}

    	entries := []struct {
    		keyType api.KeyType
    		key     *api.Key
    	}{
    		{api.KeyType_PRIVATE, &api.Key{Content: encryptedPrivateKeyBytes}},
    		{api.KeyType_PUBLIC, &api.Key{Content: publicKeyBytes}},
    		// Empty placeholder; duplicates the logic of ReserveKeyUUID.
    		{api.KeyType_CERTIFICATE, &api.Key{Content: []byte{}}},
    		{api.KeyType_AES, &api.Key{Content: encryptedAesKeyBytes}},
    		{api.KeyType_NONCE, &api.Key{Content: encryptedNonceBytes}},
    	}

    	for i, entry := range entries {
    		result, putErrors, err := dataClient.DoPutDataCall("keys", uuid+"/"+api.KeyType.String(entry.keyType),
    			entry.key, versions.EntitiesManagementAgentApiVersion)
    		response.StatusList = handlePutDataErrors(response.StatusList, putErrors, err)
    		// Any status added by handlePutDataErrors is treated as a failure.
    		if len(response.StatusList) != 0 {
    			return response, nil
    		}

    		// The success status carries the result of the final (NONCE) put.
    		if i == len(entries)-1 {
    			response.Uuid = uuid
    			response.StatusList = utils.AddStatus(response.StatusList,
    				"200", api.StatusType_INFO, result)
    		}
    	}

    	return response, nil
    }
    
    // generateKeyPair creates a new RSA key of keySize bits and returns it
    // PEM-encoded: the private key as a PKCS #8 "PRIVATE KEY" block and the public
    // key as a PKIX "PUBLIC KEY" block. Any failure (generation, validation or
    // marshalling) is returned with nil key bytes.
    func generateKeyPair(keySize int) ([]byte, []byte, error) {
    	privateKey, err := rsa.GenerateKey(rand.Reader, keySize)
    	if err != nil {
    		return nil, nil, err
    	}

    	if err := privateKey.Validate(); err != nil {
    		return nil, nil, err
    	}

    	// Both marshal calls can fail; previously their errors were dropped and
    	// PEM blocks with empty bodies would have been returned as success.
    	pkcs8PrivateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
    	if err != nil {
    		return nil, nil, err
    	}

    	pkixPublicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
    	if err != nil {
    		return nil, nil, err
    	}

    	privateKeyBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8PrivateKeyBytes})
    	publicKeyBytes := pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pkixPublicKeyBytes})

    	return privateKeyBytes, publicKeyBytes, nil
    }
    
    // serverKeyError is a minimal string-backed error type, used so a constant
    // error value can be declared without importing errors or fmt.
    type serverKeyError string

    func (e serverKeyError) Error() string { return string(e) }

    // errServerKeyNotRSA is returned by rsaEncryptWithServerKey when the server
    // certificate's public key is not an *rsa.PublicKey.
    const errServerKeyNotRSA = serverKeyError("server certificate public key is not an RSA key")

    // rsaEncryptWithServerKey encrypts message with RSA-OAEP (SHA-256, using label)
    // under the public key of the certificate read from certFilePath. It returns
    // an error, rather than panicking, when that key is not an RSA key.
    func rsaEncryptWithServerKey(certFilePath string, message []byte, label []byte) ([]byte, error) {
    	serverCertificate, err := readCertificateFromFile(certFilePath)
    	if err != nil {
    		return nil, err
    	}

    	// A bare type assertion would panic inside the request handler for any
    	// non-RSA certificate; use the comma-ok form instead.
    	serverPublicKey, ok := serverCertificate.PublicKey.(*rsa.PublicKey)
    	if !ok {
    		return nil, errServerKeyNotRSA
    	}

    	encryptedMessageBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, serverPublicKey, message, label)
    	if err != nil {
    		return nil, err
    	}

    	return encryptedMessageBytes, nil
    }
    
    // aesEncrypt seals message with AES-GCM under aesKey (no additional data).
    // It returns the ciphertext (including the GCM tag) and the freshly generated
    // nonce. The nonce is needed for decryption and is not embedded in the
    // ciphertext. An invalid key length is reported as an error.
    func aesEncrypt(aesKey []byte, message []byte) ([]byte, []byte, error) {
    	blockCipher, cipherErr := aes.NewCipher(aesKey)
    	if cipherErr != nil {
    		return nil, nil, cipherErr
    	}

    	gcm, gcmErr := cipher.NewGCM(blockCipher)
    	if gcmErr != nil {
    		return nil, nil, gcmErr
    	}

    	// generateRandomSequence takes a length in bits, hence the * 8.
    	nonceBits := gcm.NonceSize() * 8
    	nonce, nonceErr := generateRandomSequence(nonceBits)
    	if nonceErr != nil {
    		return nil, nil, nonceErr
    	}

    	sealed := gcm.Seal(nil, nonce, message, nil)
    	return sealed, nonce, nil
    }
    
    // generateRandomSequence returns keySize/8 bytes read from crypto/rand.
    // keySize is a length in bits; a remainder smaller than one byte is dropped.
    func generateRandomSequence(keySize int) ([]byte, error) {
    	buf := make([]byte, keySize/8)
    	if _, readErr := rand.Read(buf); readErr != nil {
    		return nil, readErr
    	}
    	return buf, nil
    }