#include <vereign/crypto/cert.hh>

#include <vereign/crypto/bio.hh>
#include <vereign/crypto/errors.hh>
#include <vereign/encoding/base64.hh>

#include <openssl/x509v3.h>
#include <openssl/digest.h>
#include <openssl/pem.h>
#include <sstream>

namespace {
  // Digest name used to sign certificates when CertData::Algorithms.HashAlg
  // is left empty.
  constexpr const char* certDefaultHashAlg{"SHA256"};

  // Value handed to X509_set_version. The X.509 version field is zero-based,
  // so 2 selects an X.509 v3 certificate (required for extensions).
  constexpr int certVersion{2};
}

namespace vereign::crypto::cert {

// Builds the X.509 v3 extension identified by `nid` from its textual
// configuration `value` (e.g. "critical,CA:TRUE" or "keyid,issuer:always")
// and appends it to `cert`.
//
// `issuer_cert` is registered as the issuer in the extension context, so
// extensions that derive data from the issuer (authorityKeyIdentifier) read
// it from there. For a self-signed certificate callers pass `cert` itself.
//
// Returns the result of X509_add_ext (1 on success), or 0 when the extension
// could not be built from `value`.
static auto addCertExtension(
  const X509* issuer_cert,
  X509* cert,
  int nid,
  const std::string& value
) -> int {
  X509V3_CTX ctx;

  // No configuration database: `value` must be fully self-contained.
  X509V3_set_ctx_nodb(&ctx);

  // Issuer and subject certificates; no certificate request and no CRL.
  X509V3_set_ctx(&ctx, const_cast<X509*>(issuer_cert), cert, nullptr, nullptr, 0);

  // Some OpenSSL/BoringSSL versions declare the value parameter as `char*`.
  // Hand the library a private mutable copy rather than casting away the
  // constness of the caller's string, which may be a genuinely const object.
  std::string mutable_value = value;
  bssl::UniquePtr<X509_EXTENSION> ex{X509V3_EXT_nconf_nid(
    nullptr,
    &ctx,
    nid,
    mutable_value.data()
  )};
  if (!ex) {
    return 0;
  }

  return X509_add_ext(cert, ex.get(), -1);
}

static void addSubjAltNameExt(X509* cert, const std::string& email, const std::string& url) {
  auto gens = bssl::UniquePtr<GENERAL_NAMES>(sk_GENERAL_NAME_new_null());
  if (!gens) {
    throw OpenSSLError("creating GENERAL_NAMES stack failed");
  }

  if (!email.empty()) {
    auto gen = bssl::UniquePtr<GENERAL_NAME>(GENERAL_NAME_new());
    if (!gen) {
      throw OpenSSLError("creating GENERAL_NAME failed");
    }

    auto ia5 = bssl::UniquePtr<ASN1_IA5STRING>(ASN1_IA5STRING_new());
    if (!ia5) {
      throw OpenSSLError("creating ASN1_IA5STRING failed");
    }

    auto r = ASN1_STRING_set(ia5.get(), email.data(), -1);
    if (r != 1) {
      throw OpenSSLError("set certificate alternative name email part failed");
    }

    GENERAL_NAME_set0_value(gen.get(), GEN_EMAIL, ia5.release());
    r = sk_GENERAL_NAME_push(gens.get(), gen.release());
    if (r == 0) {
      throw OpenSSLError("pushing email to certificate subject alternative name failed");
    }
  }

  if (!url.empty()) {
    auto gen = bssl::UniquePtr<GENERAL_NAME>(GENERAL_NAME_new());
    if (!gen) {
      throw OpenSSLError("creating GENERAL_NAME failed");
    }

    auto ia5 = bssl::UniquePtr<ASN1_IA5STRING>(ASN1_IA5STRING_new());
    if (!ia5) {
      throw OpenSSLError("creating ASN1_IA5STRING failed");
    }

    auto r = ASN1_STRING_set(ia5.get(), url.data(), -1);
    if (r != 1) {
      throw OpenSSLError("set certificate alternative name URL part failed");
    }

    GENERAL_NAME_set0_value(gen.get(), GEN_DNS, ia5.release());
    r = sk_GENERAL_NAME_push(gens.get(), gen.release());
    if (r == 0) {
      throw OpenSSLError("pushing URL to certificate subject alternative name failed");
    }
  }

  auto r = X509_add1_ext_i2d(cert, NID_subject_alt_name, gens.get(), 0, 0);
  if (r != 1) {
    throw OpenSSLError("set certificate alternative name failed");
  }
}

static void setCertSubject(const CertData& cert_data, X509* cert) {
  const CertSubject& subject = cert_data.Subject;
  auto subject_name = X509_get_subject_name(cert);
  int r = 0;

  if (!subject.CommonName.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_commonName,
      MBSTRING_ASC,
      (unsigned char*) subject.CommonName.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Common Name failed");
    }
  }

  if (!subject.Country.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_countryName,
      MBSTRING_ASC,
      (unsigned char*) subject.Country.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Country Name failed");
    }
  }

  if (!subject.Locality.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_localityName,
      MBSTRING_ASC,
      (unsigned char*) subject.Locality.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Locality Name failed");
    }
  }

  if (!subject.State.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_stateOrProvinceName,
      MBSTRING_ASC,
      (unsigned char*) subject.State.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject State Name failed");
    }
  }

  if (!subject.Organization.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_organizationName,
      MBSTRING_ASC,
      (unsigned char*) subject.Organization.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Organization Name failed");
    }
  }

  if (!subject.OrganizationUnit.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_organizationalUnitName,
      MBSTRING_ASC,
      (unsigned char*) subject.OrganizationUnit.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Organizational Unit Name failed");
    }
  }


  if (!cert_data.Email.empty()) {
    r = X509_NAME_add_entry_by_NID(
      subject_name,
      NID_pkcs9_emailAddress,
      MBSTRING_ASC,
      (unsigned char*) cert_data.Email.data(),
      -1,
      -1,
      0
    );
    if (r != 1) {
      throw OpenSSLError("set certificate subject Email Address failed");
    }
  }
}

static void addCertExtensions(const CertData& cert_data, const X509* issuer_cert, X509* cert) {
  // basic constraints
  const std::string basic_constraints = cert_data.IsCA ? "critical,CA:TRUE" : "critical,CA:FALSE";
  auto r = addCertExtension(issuer_cert, cert, NID_basic_constraints, basic_constraints);
  if (r != 1) {
    throw OpenSSLError("set certificate basic constraints failed");
  }

  // key usage
  const std::string key_usage = cert_data.IsCA
    ? "critical,digitalSignature,keyCertSign,cRLSign"
    : "critical,digitalSignature,nonRepudiation,keyEncipherment,dataEncipherment";
  r = addCertExtension(issuer_cert, cert, NID_key_usage, key_usage);
  if (r != 1) {
    throw OpenSSLError("set certificate key usage failed");
  }

  // extended key usage
  if (!cert_data.IsCA && !cert_data.Email.empty()) {
    r = addCertExtension(issuer_cert, cert, NID_ext_key_usage, "emailProtection");
    if (r != 1) {
      throw OpenSSLError("set certificate extended key usage failed");
    }
  }

  // subject alternative name
  if (!cert_data.Email.empty() || !cert_data.Url.empty()) {
    addSubjAltNameExt(cert, cert_data.Email, cert_data.Url);
  }

  // subject key identifier
  r = addCertExtension(issuer_cert, cert, NID_subject_key_identifier, "hash");
  if (r != 1) {
    throw OpenSSLError("set certificate key usage failed");
  }

  // authority key identifier
  r = addCertExtension(issuer_cert, cert, NID_authority_key_identifier, "keyid,issuer:always");
  if (r != 1) {
    throw OpenSSLError("set certificate key usage failed");
  }
}

// Sets the notBefore / notAfter validity window of `cert`.
//
// - notBefore is cert_data.Validity.NotBefore or, when that is 0, the start
//   of the current day (current Unix time rounded down to a multiple of
//   86400 seconds).
// - notAfter is cert_data.Validity.NotAfter or, when that is 0, notBefore
//   plus ValidYears (1 when ValidYears is 0) periods of 365 days.
//
// @throws OpenSSLError if either time cannot be written.
static void setCertValidity(const CertData& cert_data, X509* cert) {
  constexpr time_t secondsPerDay = 24 * 60 * 60;
  constexpr time_t secondsPerYear = 365 * secondsPerDay;

  time_t not_before = cert_data.Validity.NotBefore;
  if (not_before == 0) {
    const auto since_epoch = std::chrono::system_clock::now().time_since_epoch();
    not_before = std::chrono::duration_cast<std::chrono::seconds>(since_epoch).count();
    not_before -= not_before % secondsPerDay;
  }

  auto timeret = X509_time_adj(X509_get_notBefore(cert), 0, &not_before);
  if (timeret == nullptr) {
    throw OpenSSLError("adjusting certificate not before validity failed");
  }

  time_t not_after = cert_data.Validity.NotAfter;
  if (not_after == 0) {
    int valid_years = cert_data.Validity.ValidYears;
    if (valid_years == 0) {
      valid_years = 1;
    }

    // Do the arithmetic in time_t. `valid_years * 365 * 24 * 60 * 60` in int
    // overflows (undefined behavior) once ValidYears >= 69, and a `long`
    // offset is only 32 bits on LLP64 platforms.
    not_after = not_before + static_cast<time_t>(valid_years) * secondsPerYear;
  }

  timeret = X509_time_adj(X509_get_notAfter(cert), 0, &not_after);
  if (timeret == nullptr) {
    throw OpenSSLError("adjusting certificate not after validity failed");
  }
}

// Builds and signs an X.509 v3 certificate for the public key `pkey`,
// described by `cert_data`.
//
// When either `issuer_cert` or `issuer_pkey` is null the certificate is
// self-signed: it becomes its own issuer and is signed with `pkey`.
// The digest is cert_data.Algorithms.HashAlg, or certDefaultHashAlg if empty.
//
// @throws OpenSSLError on an unknown digest name or any OpenSSL failure.
static auto createCert(
  const CertData& cert_data,
  const X509* issuer_cert,
  const EVP_PKEY* issuer_pkey,
  const EVP_PKEY* pkey
) -> bssl::UniquePtr<X509> {
  const char* hash_alg_name = cert_data.Algorithms.HashAlg.empty()
    ? certDefaultHashAlg
    : cert_data.Algorithms.HashAlg.data();
  const EVP_MD* hash_alg = EVP_get_digestbyname(hash_alg_name);
  if (hash_alg == nullptr) {
    throw OpenSSLError("cannot find hash algorithm");
  }

  auto cert = bssl::UniquePtr<X509>(X509_new());
  if (!cert) {
    throw OpenSSLError("creating certificate failed");
  }

  // NOTE(review): supplying only one of issuer_cert / issuer_pkey silently
  // falls back to self-signing; confirm this is intended.
  if (issuer_cert == nullptr || issuer_pkey == nullptr) {
    issuer_cert = cert.get();
    issuer_pkey = pkey;
  }

  // set public key
  auto r = X509_set_pubkey(cert.get(), const_cast<EVP_PKEY*>(pkey));
  if (r != 1) {
    throw OpenSSLError("set public key to certificate failed");
  }

  // set certificate version
  r = X509_set_version(cert.get(), certVersion);
  if (r != 1) {
    throw OpenSSLError("set version to certificate failed");
  }

  // set serial number
  auto serial_number = cert_data.SerialNumber;
  if (serial_number == 0) {
    // FIXME: is using time ok ? A millisecond timestamp is predictable and can
    // collide for certificates created in the same millisecond.
    serial_number = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::system_clock::now().time_since_epoch()
    ).count();
  }
  r = ASN1_INTEGER_set_uint64(X509_get_serialNumber(cert.get()), serial_number);
  if (r != 1) {
    throw OpenSSLError("set certificate serial number failed");
  }

  // set subject
  setCertSubject(cert_data, cert.get());

  // Set the issuer name after the subject: for a self-signed certificate the
  // issuer name is this certificate's own subject name.
  auto issuer_name = X509_get_subject_name(const_cast<X509*>(issuer_cert));
  r = X509_set_issuer_name(cert.get(), issuer_name);
  if (r != 1) {
    throw OpenSSLError("set certificate issuer name failed");
  }

  // set validity
  setCertValidity(cert_data, cert.get());

  // Extensions are added after name, serial and key are set, because the key
  // identifier extensions are derived from the subject/issuer certificate.
  addCertExtensions(cert_data, issuer_cert, cert.get());

  // X509_sign returns the signature length on success, 0 (or negative) on error.
  r = X509_sign(cert.get(), const_cast<EVP_PKEY*>(issuer_pkey), hash_alg);
  if (r <= 0) {
    throw OpenSSLError("certificate sign failed");
  }

  return cert;
}

// Creates a certificate for `pkey`, issued by `issuer_cert` and signed with
// `issuer_pkey`. If either issuer argument is null, the certificate is
// self-signed with `pkey` instead.
auto CreateCert(
  const CertData& cert_data,
  const X509* issuer_cert,
  const EVP_PKEY* issuer_pkey,
  const EVP_PKEY* pkey
) -> bssl::UniquePtr<X509> {
  auto issued = createCert(cert_data, issuer_cert, issuer_pkey, pkey);
  return issued;
}

// Creates a self-signed certificate: `pkey` is both the certified key and
// the signing key, and the certificate is its own issuer.
auto CreateSelfSignedCert(
  const CertData& cert_data,
  const EVP_PKEY* pkey
) -> bssl::UniquePtr<X509> {
  constexpr const X509* no_issuer_cert = nullptr;
  constexpr const EVP_PKEY* no_issuer_pkey = nullptr;
  return createCert(cert_data, no_issuer_cert, no_issuer_pkey, pkey);
}

// Serializes `cert` in PEM form into a newly allocated in-memory BIO,
// which is returned to the caller.
//
// @throws OpenSSLError if the BIO cannot be created or PEM encoding fails.
auto ExportCertToPEM(const X509* cert) -> bssl::UniquePtr<BIO> {
  auto out = bssl::UniquePtr<BIO>(BIO_new(BIO_s_mem()));
  if (out == nullptr) {
    throw OpenSSLError("creating memory buffer failed");
  }

  const int written = PEM_write_bio_X509(out.get(), const_cast<X509*>(cert));
  if (written != 1) {
    throw OpenSSLError("exporting certificate to PEM failed");
  }

  return out;
}

// Parses a PEM-encoded certificate from the bytes in `pem`.
//
// @throws OpenSSLError if the read buffer cannot be created or no
//         certificate can be parsed.
auto ImportCertFromPEM(bytes::View pem) -> bssl::UniquePtr<X509> {
  auto in = bssl::UniquePtr<BIO>(BIO_new_mem_buf(pem.Data(), pem.Size()));
  if (!in) {
    throw OpenSSLError("creating memory buffer failed");
  }

  bssl::UniquePtr<X509> parsed(PEM_read_bio_X509(in.get(), nullptr, nullptr, nullptr));
  if (!parsed) {
    throw OpenSSLError("importing certificate from PEM failed");
  }

  return parsed;
}

// Writes the human-readable text dump of `cert` produced by X509_print to `os`.
//
// @throws OpenSSLError if the buffer cannot be created or printing fails.
void PrintCert(std::ostream& os, const X509* cert) {
  auto text = bssl::UniquePtr<BIO>(BIO_new(BIO_s_mem()));
  if (text == nullptr) {
    throw OpenSSLError("creating memory buffer failed");
  }

  const int printed = X509_print(text.get(), const_cast<X509*>(cert));
  if (printed != 1) {
    throw OpenSSLError("printing certificate failed");
  }

  os << bio::View(text.get()).String();
}

} // vereign::crypto::cert