/*
Copyright (c) 2018 Vereign AG [https://www.vereign.com]

This is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package handler

import (
	"code.vereign.com/code/key-storage-agent/config"
	"code.vereign.com/code/viam-apis/errors"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/pem"

	"code.vereign.com/code/viam-apis/log"

	keyutils "code.vereign.com/code/key-storage-agent/utils"
	"code.vereign.com/code/viam-apis/key-storage-agent/api"
	"golang.org/x/net/context"
)

// GenerateKeyPair creates a new RSA key pair and persists it, together with
// the material needed to decrypt the private key, under a freshly generated
// UUID.
//
// in.KeySize is the RSA modulus size in bits. The private key is encrypted
// with a newly generated 256-bit AES key (AES-GCM); the AES key and the GCM
// nonce are in turn encrypted with the server's RSA public key. Five entries
// are written with PutData under "keys", each named "<uuid>/<KeyType>": the
// encrypted private key, the public key, an empty certificate placeholder,
// the encrypted AES key and the encrypted nonce.
//
// On success the response carries the UUID under which the entries were
// stored.
//
// All key generation and encryption is completed before the first write, so
// a cryptographic failure cannot leave behind a stored private key whose AES
// key or nonce is missing. The writes themselves are not transactional: if
// one of them fails, entries written before it remain in place.
func (s *KeyStorageServerImpl) GenerateKeyPair(ctx context.Context,
	in *api.GenerateKeyPairRequest) (*api.GenerateKeyPairResponse, error) {

	auth := s.CreateAuthentication(ctx)

	client := keyutils.CreateDataStorageClient(auth)
	defer client.CloseClient()

	uuid, err := keyutils.GenerateUnusedUUID(client)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	privateKeyBytes, publicKeyBytes, err := generateKeyPair(int(in.KeySize))
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	// generateRandomSequence takes a size in bits: 256 bits -> 32-byte AES key.
	aesKeyBytes, err := generateRandomSequence(256)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	encryptedPrivateKeyBytes, privateKeyNonce, err := aesEncrypt(aesKeyBytes, privateKeyBytes)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	// Wrap the AES key and the nonce with the server key before anything is
	// stored; the labels must match the ones used when decrypting.
	encryptedAesKeyBytes, err := rsaEncryptWithServerKey(aesKeyBytes, []byte("aeskeys"))
	if err != nil {
		return nil, errors.WrapInternal(err, "Could not encrypt")
	}

	encryptedPrivateKeyNonceBytes, err := rsaEncryptWithServerKey(privateKeyNonce, []byte("nonce"))
	if err != nil {
		return nil, errors.WrapInternal(err, "Could not encrypt private key")
	}

	// Entries are written in this order. The empty certificate entry
	// duplicates the logic of ReserveKeyUUID.
	entries := []struct {
		keyType api.KeyType
		content []byte
	}{
		{api.KeyType_PRIVATE, encryptedPrivateKeyBytes},
		{api.KeyType_PUBLIC, publicKeyBytes},
		{api.KeyType_CERTIFICATE, []byte{}},
		{api.KeyType_AES, encryptedAesKeyBytes},
		{api.KeyType_NONCE, encryptedPrivateKeyNonceBytes},
	}

	for _, entry := range entries {
		name := uuid + "/" + entry.keyType.String()
		if _, _, err = client.PutData("keys", name, &api.Key{Content: entry.content}); err != nil {
			return nil, errors.WrapInternalFormat(err, "Could not store key %s", name)
		}
	}

	return &api.GenerateKeyPairResponse{Uuid: uuid}, nil
}

// generateKeyPair creates a new RSA key pair with a modulus of keySize bits.
//
// It returns, in order, the private key as a PEM block of type "PRIVATE KEY"
// wrapping its PKCS#8 encoding, and the public key as a PEM block of type
// "PUBLIC KEY" wrapping its PKIX encoding. On any failure the error is logged
// and returned, and both byte slices are nil.
func generateKeyPair(keySize int) ([]byte, []byte, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, keySize)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}

	// Sanity-check the freshly generated key before serialising it.
	if err := rsaKey.Validate(); err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}

	privateDER, err := x509.MarshalPKCS8PrivateKey(rsaKey)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}
	privatePEM := pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: privateDER,
	})

	publicDER, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}
	publicPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: publicDER,
	})

	return privatePEM, publicPEM, nil
}

// rsaEncryptWithServerKey encrypts message with RSA-OAEP, using SHA-256 as
// the hash, under the public key of config.EncryptionCert. label is passed to
// OAEP as the label parameter.
//
// It returns the ciphertext, or a logged error when the server certificate
// does not carry an RSA public key (x509.ErrUnsupportedAlgorithm) or when
// rsa.EncryptOAEP fails.
func rsaEncryptWithServerKey(message []byte, label []byte) ([]byte, error) {
	// Use the checked assertion: a certificate with a non-RSA key must yield
	// an error to the caller instead of panicking the server.
	serverPublicKey, ok := config.EncryptionCert.PublicKey.(*rsa.PublicKey)
	if !ok || serverPublicKey == nil {
		err := x509.ErrUnsupportedAlgorithm
		log.Printf("Error: server encryption certificate does not hold an RSA public key: %v", err)
		return nil, err
	}

	encryptedMessageBytes, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, serverPublicKey, message, label)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	return encryptedMessageBytes, nil
}

// aesEncrypt encrypts message with AES in GCM mode under aesKey, using a
// newly generated random nonce and no additional authenticated data.
//
// It returns the output of Seal followed by the nonce that was used. The
// nonce is not embedded in the first return value, so the caller has to keep
// it to be able to decrypt. On failure the error is logged and returned, and
// both byte slices are nil.
func aesEncrypt(aesKey []byte, message []byte) ([]byte, []byte, error) {
	blockCipher, err := aes.NewCipher(aesKey)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}

	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}

	// generateRandomSequence expects a size in bits, NonceSize reports bytes.
	nonceBits := gcm.NonceSize() * 8
	nonce, err := generateRandomSequence(nonceBits)
	if err != nil {
		log.Printf("Error: %v", err)
		return nil, nil, err
	}

	return gcm.Seal(nil, nonce, message, nil), nonce, nil
}

// generateRandomSequence returns keySize/8 bytes read from crypto/rand.
//
// keySize is a length in bits. The division is an integer division, so a
// size that is not a multiple of 8 is rounded down to whole bytes. If the
// random source fails, the error is logged and returned with a nil slice.
func generateRandomSequence(keySize int) ([]byte, error) {
	randomBytes := make([]byte, keySize/8)

	if _, err := rand.Read(randomBytes); err != nil {
		log.Printf("Error: %v", err)
		return nil, err
	}

	return randomBytes, nil
}