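// Package server wires up the key-storage-agent's network endpoints: the TLS
// gRPC server, the REST gateway in front of it, and the Prometheus metrics
// endpoint.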
package server
/*
Copyright (c) 2018 Vereign AG [https://www.vereign.com]
This is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"

	"code.vereign.com/code/key-storage-agent/config"
	"code.vereign.com/code/key-storage-agent/handler"
	"code.vereign.com/code/key-storage-agent/session"
	"code.vereign.com/code/key-storage-agent/utils"
	"code.vereign.com/code/viam-apis/authentication"
	api "code.vereign.com/code/viam-apis/key-storage-agent/api"
	// NOTE(review): unaryInterceptor calls errors.Log, which is not the stdlib
	// errors package; its (project) import is not visible here — confirm.
	"github.com/grpc-ecosystem/grpc-gateway/runtime"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
)
// Registry of per-method request-duration summaries. createQueryTime looks a
// summary up by metric name and creates/registers it only when absent, so
// each name is registered with Prometheus once. mutex guards summaries.
var (
	mutex     sync.RWMutex
	summaries = make(map[string]prometheus.Summary)
)
// contextKey is unexported so that values this package stores in a
// context.Context cannot collide with keys defined by other packages.
type contextKey int

const (
	// clientIDKey is the key under which unaryInterceptor stores the UUID
	// returned by authenticateClient.
	clientIDKey contextKey = iota
)
// PEM material captured by StartGRPCServer and kept at package level.
var (
	pkgCertPEM   []byte // server certificate (certPEM)
	pkgKeyPEM    []byte // server private key (privateKeyPEM)
	pkgCaCertPEM []byte // CA certificate(s) (caCertPEM)
)
// credMatcher is the grpc-gateway incoming-header matcher installed by
// StartRESTServer. It selects which HTTP request headers are forwarded to the
// gRPC backend as metadata and under which key; returning ok=false drops the
// header.
//
// authenticateClient reads both the "uuid" and "session" metadata keys, so
// both credential headers are forwarded. The names are matched in their
// canonical MIME form ("Uuid", "Session"), which is presumably how the
// gateway hands them in — confirm against the grpc-gateway version in use.
func credMatcher(headerName string) (mdName string, ok bool) {
	switch headerName {
	case "Session", "Uuid":
		return headerName, true
	default:
		return "", false
	}
}
// authenticateClient checks the credentials carried in the incoming gRPC
// metadata and returns the caller's UUID on success.
//
// The "uuid" and "session" metadata values are compared against
// config.SystemAuth. If the UUID equals the system UUID, the session must
// equal config.SystemAuth.Session exactly. Otherwise the pair is validated via
// session.CheckSession, using a data-storage client created with the system
// credentials. An error is returned when no metadata is present or the
// session is rejected.
//
// s and invokedMethod are currently unused.
func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, invokedMethod string) (string, error) {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return "", fmt.Errorf("missing credentials")
	}

	// Repeated values for the same key are concatenated.
	clientAuth := &authentication.Authentication{
		Uuid:    strings.Join(md["uuid"], ""),
		Session: strings.Join(md["session"], ""),
	}
	viamAuth := &authentication.Authentication{
		Uuid:    config.SystemAuth.Uuid,
		Session: config.SystemAuth.Session,
	}

	if clientAuth.Uuid == viamAuth.Uuid {
		// System caller: the configured session is authoritative, so no
		// data-storage round trip is needed.
		// NOTE(review): if config.SystemAuth.Uuid is empty, a caller sending no
		// uuid takes this path — confirm SystemAuth is always populated.
		if clientAuth.Session != viamAuth.Session {
			return "", fmt.Errorf("bad session %s", clientAuth.Session)
		}
		return clientAuth.Uuid, nil
	}

	// Only open a data-storage client when a remote session check is needed;
	// the defer closes it on every path below.
	sessionClient := utils.CreateDataStorageClient(viamAuth)
	defer sessionClient.CloseClient()

	if !session.CheckSession(clientAuth.Uuid, clientAuth.Session, sessionClient) {
		return "", fmt.Errorf("bad session %s", clientAuth.Session)
	}
	return clientAuth.Uuid, nil
}
// unaryInterceptor authenticates every unary gRPC call before it reaches the
// KeyStorage handlers.
//
// It fails the call if info.Server is not a *handler.KeyStorageServerImpl or if
// authenticateClient rejects the caller's credentials. On success, the caller's
// UUID is stored in the context under clientIDKey and the request is passed to
// handler1. Handler errors are logged via errors.Log and returned unchanged.
//
// The handler parameter is named handler1 because "handler" is already the
// imported package name.
func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler1 grpc.UnaryHandler) (interface{}, error) {
	s, ok := info.Server.(*handler.KeyStorageServerImpl)
	if !ok {
		return nil, fmt.Errorf("unable to cast server")
	}

	clientID, err := authenticateClient(ctx, s, info.FullMethod)
	if err != nil {
		return nil, err
	}
	ctx = context.WithValue(ctx, clientIDKey, clientID)

	res, err := handler1(ctx, req)
	if err != nil {
		errors.Log(err)
	}
	return res, err
}

// metricNameReplacer maps '/' and '.' in a gRPC method name to '_' in a single
// pass, e.g. "/pkg.Service/Method" -> "_pkg_Service_Method".
var metricNameReplacer = strings.NewReplacer("/", "_", ".", "_")

// createQueryTime returns the request-duration Summary for the gRPC method
// funcName, creating and registering it with Prometheus on first use.
//
// The metric is named "<endpoint>_grpc_request_duration_seconds", where
// <endpoint> is funcName with '/' and '.' replaced by '_'. When
// config.Config.MetricEnvironmentPrefix is set, the name is prefixed with it
// and "_". Summaries are cached in summaries so each name is registered once.
// prometheus.MustRegister panics if registration fails.
func createQueryTime(funcName string) prometheus.Summary {
	metricName := metricNameReplacer.Replace(funcName) + "_grpc_request_duration_seconds"
	if prefix := config.Config.MetricEnvironmentPrefix; prefix != "" {
		metricName = prefix + "_" + metricName
	}

	// Fast path: after warm-up almost every call finds an existing summary,
	// so a shared read lock is enough.
	mutex.RLock()
	queryTime, ok := summaries[metricName]
	mutex.RUnlock()
	if ok {
		return queryTime
	}

	mutex.Lock()
	defer mutex.Unlock()
	// Re-check under the write lock: another goroutine may have registered
	// the summary between RUnlock and Lock.
	if existing, found := summaries[metricName]; found {
		return existing
	}
	queryTime = prometheus.NewSummary(prometheus.SummaryOpts{
		Name: metricName,
		Help: "grpc request duration seconds of /" + funcName + " for " + config.Config.MetricEnvironmentPrefix + " env",
	})
	prometheus.MustRegister(queryTime)
	summaries[metricName] = queryTime
	return queryTime
}
// StartGRPCServer serves the KeyStorage gRPC API over TLS on address and blocks
// until the server stops.
//
// certPEM/privateKeyPEM form the server key pair. caCertPEM is appended to the
// system certificate pool (or to an empty pool if the system pool is
// unavailable) and installed as ClientCAs; client certificates are not
// currently required. The PEM inputs, dataStorageAddress and maxMessageSize
// are handed to handler.KeyStorageServerImpl. The receive-size limit is taken
// from config.Config.MaxMessageSize, in MiB.
// NOTE(review): the maxMessageSize parameter is not used for the gRPC
// receive limit — confirm whether the two are meant to be the same value.
//
// It returns an error if the key pair or CA certificates cannot be loaded, if
// the listener cannot be created, or if Serve fails.
func StartGRPCServer(address string, certPEM, privateKeyPEM, caCertPEM, vereignCertPEM []byte, dataStorageAddress string, maxMessageSize int) error {
	pkgCertPEM = certPEM
	pkgKeyPEM = privateKeyPEM
	pkgCaCertPEM = caCertPEM

	// Load the server certificate from the PEM strings.
	certificate, err := tls.X509KeyPair(certPEM, privateKeyPEM)
	if err != nil {
		return fmt.Errorf("could not load server key pair: %s", err)
	}

	// Start from the system pool (the error is deliberately ignored; fall back
	// to an empty pool) and add the CA certificate(s).
	certPool, _ := x509.SystemCertPool()
	if certPool == nil {
		certPool = x509.NewCertPool()
	}
	if ok := certPool.AppendCertsFromPEM(caCertPEM); !ok {
		return fmt.Errorf("failed to append server certs")
	}

	creds := credentials.NewTLS(&tls.Config{
		//ClientAuth: tls.RequireAndVerifyClientCert,
		Certificates: []tls.Certificate{certificate},
		ClientCAs:    certPool,
	})

	opts := []grpc.ServerOption{
		grpc.Creds(creds),
		grpc.UnaryInterceptor(unaryInterceptor),
		grpc.MaxRecvMsgSize(config.Config.MaxMessageSize * 1024 * 1024),
	}
	grpcServer := grpc.NewServer(opts...)

	s := handler.KeyStorageServerImpl{
		DataStorageUrl: dataStorageAddress,
		CertPEM:        certPEM,
		KeyPEM:         privateKeyPEM,
		CaCertPEM:      caCertPEM,
		VereignCertPEM: vereignCertPEM,
		MaxMessageSize: maxMessageSize,
	}
	// Attach the KeyStorage service to the server.
	api.RegisterKeyStorageServer(grpcServer, &s)

	// Listen only after every fallible setup step, so an early return cannot
	// leak the socket; Serve closes the listener when it returns.
	lis, err := net.Listen("tcp", address)
	if err != nil {
		return fmt.Errorf("failed to listen: %v", err)
	}

	log.Printf("starting HTTP/2 gRPC server on %s", address)
	if err := grpcServer.Serve(lis); err != nil {
		return fmt.Errorf("failed to serve: %s", err)
	}
	return nil
}
// StartRESTServer serves a grpc-gateway REST/JSON front end for the KeyStorage
// API on address, proxying requests to the gRPC server at grpcAddress over TLS.
//
// certPEM is appended to the system certificate pool (or to an empty pool if
// the system pool is unavailable), and that pool is the trust root for the
// TLS connection to grpcAddress. Only headers accepted by credMatcher are
// forwarded as gRPC metadata.
//
// It blocks while serving. It returns an error if the certificate pool or
// gateway registration cannot be set up, or the error that stopped the HTTP
// server.
func StartRESTServer(address, grpcAddress string, certPEM []byte) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	mux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(credMatcher))

	// The error is deliberately ignored: fall back to an empty pool.
	certPool, _ := x509.SystemCertPool()
	if certPool == nil {
		certPool = x509.NewCertPool()
	}
	if ok := certPool.AppendCertsFromPEM(certPEM); !ok {
		return fmt.Errorf("failed to append client certs")
	}
	creds := credentials.NewClientTLSFromCert(certPool, "")
	opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)}

	// Register the KeyStorage gateway handlers against the gRPC endpoint.
	if err := api.RegisterKeyStorageHandlerFromEndpoint(ctx, mux, grpcAddress, opts); err != nil {
		return fmt.Errorf("could not register service KeyStorageServer: %s", err)
	}

	log.Printf("starting HTTP/1.1 REST server on %s", address)
	// ListenAndServe only returns on failure (e.g. the address is in use), so
	// surface that error instead of reporting success.
	return http.ListenAndServe(address, mux)
}
// NOTE(review): this fragment's enclosing func signature, and the statement
// whose body is closed by the first bare `}` below, are not visible here —
// restore them before building. Given `return err` / `return nil`, the
// function presumably returns error.
// start prometheus: serve promhttp.Handler() on /metrics at
// config.Config.PrometeusListenAddress.
promHandler := http.NewServeMux()
promHandler.Handle("/metrics", promhttp.Handler())
log.Println("Starting prometheus...")
// http.ListenAndServe blocks while serving and returns only on failure.
err := http.ListenAndServe(config.Config.PrometeusListenAddress, promHandler)
if err != nil {
return err
}
}
// Reached only when the (not visible) enclosing branch above is skipped,
// since ListenAndServe always returns a non-nil error.
return nil
}