package server /* Copyright (c) 2018 Vereign AG [https://www.vereign.com] This is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. */ import ( "crypto/tls" "crypto/x509" "fmt" "log" "net" "net/http" "strings" "code.vereign.com/code/key-storage-agent/config" "code.vereign.com/code/key-storage-agent/handler" "code.vereign.com/code/key-storage-agent/session" "code.vereign.com/code/key-storage-agent/utils" "code.vereign.com/code/viam-apis/authentication" api "code.vereign.com/code/viam-apis/key-storage-agent/api" "github.com/grpc-ecosystem/grpc-gateway/runtime" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" ) // private type for Context keys type contextKey int const ( clientIDKey contextKey = iota ) var pkgCertPEM []byte var pkgKeyPEM []byte var pkgCaCertPEM []byte func credMatcher(headerName string) (mdName string, ok bool) { if headerName == "Session" { return headerName, true } if headerName == "Uuid" { return headerName, true } return "", false } // authenticateAgent check the client credentials func authenticateClient(ctx context.Context, s *handler.KeyStorageServerImpl, invokedMethod string) (string, error) { if md, ok := metadata.FromIncomingContext(ctx); ok { clientAuth := &authentication.Authentication{ Uuid: strings.Join(md["uuid"], ""), Session: strings.Join(md["session"], ""), } viamAuth := &authentication.Authentication{ Uuid: config.SystemAuth.Uuid, Session: config.SystemAuth.Session, } sessionClient := utils.CreateDataStorageClient(viamAuth) defer sessionClient.CloseClient() if clientAuth.Uuid == viamAuth.Uuid { if clientAuth.Session != viamAuth.Session { return "", fmt.Errorf("bad session %s", clientAuth.Session) } } else { if session.CheckSession(clientAuth.Uuid, clientAuth.Session, sessionClient) == false { return "", fmt.Errorf("bad session %s", clientAuth.Session) } } return clientAuth.Uuid, nil } return "", fmt.Errorf("missing credentials") } // unaryInterceptor call authenticateClient with current context func unaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler1 grpc.UnaryHandler) (interface{}, error) { s, ok := info.Server.(*handler.KeyStorageServerImpl) if !ok { return nil, fmt.Errorf("unable to cast server") } clientID, err := authenticateClient(ctx, s, info.FullMethod) if err != nil { log.Printf("Error: %v", err) return nil, err } ctx = context.WithValue(ctx, clientIDKey, clientID) return handler1(ctx, req) } func StartGRPCServer(address string, certPEM, privateKeyPEM, caCertPEM, vereignCertPEM, vereignPrivateKeyPEM []byte, dataStorageAddress string, maxMessageSize int) error { pkgCertPEM = certPEM pkgKeyPEM = privateKeyPEM pkgCaCertPEM = caCertPEM // create a listener on TCP port lis, err := net.Listen("tcp", address) if err != nil { log.Printf("Error: %v", err) return fmt.Errorf("failed to listen: %v", err) } // create a server instance s := handler.KeyStorageServerImpl{ DataStorageUrl: dataStorageAddress, CertPEM: certPEM, KeyPEM: privateKeyPEM, CaCertPEM: caCertPEM, VereignCertPEM: vereignCertPEM, VereignPrivateKeyPEM: vereignPrivateKeyPEM, MaxMessageSize: maxMessageSize, } // Load the certificates from PEM Strings certificate, err := tls.X509KeyPair(certPEM, privateKeyPEM) if err != nil { log.Printf("Error: %v", err) return fmt.Errorf("could not load server key pair: %s", err) } // Create a certificate pool from the certificate authority // Get the SystemCertPool, continue with an empty pool on error certPool, _ := x509.SystemCertPool() if certPool == nil { certPool = x509.NewCertPool() } if ok := certPool.AppendCertsFromPEM(caCertPEM); !ok { return fmt.Errorf("failed to append server certs") } // Create the TLS credentials creds := credentials.NewTLS(&tls.Config{ //ClientAuth: tls.RequireAndVerifyClientCert, Certificates: []tls.Certificate{certificate}, ClientCAs: certPool, }) // Create an array of gRPC options with the credentials opts := []grpc.ServerOption{ grpc.Creds(creds), grpc.UnaryInterceptor(unaryInterceptor), grpc.MaxRecvMsgSize(config.MaxMessageSize * 1024 * 1024), } // create a gRPC server object grpcServer := grpc.NewServer(opts...) // attach the CalcMinimumDistance service to the server api.RegisterKeyStorageServer(grpcServer, &s) // start the server log.Printf("starting HTTP/2 gRPC server on %s", address) if err := grpcServer.Serve(lis); err != nil { return fmt.Errorf("failed to serve: %s", err) } return nil } func StartRESTServer(address, grpcURL string, certPEM, keyPEM []byte) error { log.Println("grpcAddress: ", grpcURL) ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() mux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(credMatcher)) certPool, err := x509.SystemCertPool() if certPool == nil { certPool = x509.NewCertPool() } // Append the client certificates from the CA if ok := certPool.AppendCertsFromPEM(certPEM); !ok { return fmt.Errorf("failed to append client certs") } creds := credentials.NewClientTLSFromCert(certPool, "") // Setup the client gRPC options opts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} // Register RedisStorageServer err = api.RegisterKeyStorageHandlerFromEndpoint(ctx, mux, grpcURL, opts) if err != nil { log.Printf("Error: %v", err) return fmt.Errorf("could not register service RedisStorageServer: %s", err) } // server certificate certificate, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { log.Printf("Error: %v", err) return fmt.Errorf("could not load server key pair: %s", err) } serverTLSConfig := &tls.Config{ Certificates: []tls.Certificate{certificate}, } tlsServer := &http.Server{ Addr: address, Handler: mux, TLSConfig: serverTLSConfig, } log.Printf("starting HTTP/1.1 REST server on %s", address) tlsServer.ListenAndServeTLS("","") return nil }