From 0b312744773ab7c32dcf9195b43c5eb990c9ca7c Mon Sep 17 00:00:00 2001
From: Manish R Jain <manishrjain@gmail.com>
Date: Wed, 24 Feb 2016 18:17:42 +1100
Subject: [PATCH] Custom encoders for communication between workers via net/rpc
 over TCP. Also, hello there!

---
 conn/client.go   |  57 +++++++++++++++++++++++
 conn/codec.go    |  60 +++++++++++++++++++++++++
 conn/pool.go     |  64 ++++++++++++++++++++++++++
 conn/server.go   |  61 +++++++++++++++++++++++++
 server/main.go   |   1 +
 task.fbs         |   4 ++
 task/XidList.go  |  38 ++++++++++++++++
 test/client.go   | 102 -----------------------------------------
 test/main.go     | 115 -----------------------------------------------
 test/server.go   |  79 --------------------------------
 worker/worker.go |  93 ++++++++++++++++++++++++++++++++++++++
 x/x.go           |   6 +++
 12 files changed, 384 insertions(+), 296 deletions(-)
 create mode 100644 conn/client.go
 create mode 100644 conn/codec.go
 create mode 100644 conn/pool.go
 create mode 100644 conn/server.go
 create mode 100644 task/XidList.go
 delete mode 100644 test/client.go
 delete mode 100644 test/main.go
 delete mode 100644 test/server.go

diff --git a/conn/client.go b/conn/client.go
new file mode 100644
index 00000000..2b668dc0
--- /dev/null
+++ b/conn/client.go
@@ -0,0 +1,57 @@
+package conn
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"net/rpc"
+)
+
+type ClientCodec struct {
+	Rwc        io.ReadWriteCloser
+	payloadLen int32
+}
+
+func (c *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error {
+	if body == nil {
+		return fmt.Errorf("Nil request body from client.")
+	}
+
+	query := body.(*Query)
+	if err := writeHeader(c.Rwc, r.Seq, r.ServiceMethod, query.Data); err != nil {
+		return err
+	}
+	n, err := c.Rwc.Write(query.Data)
+	if n != len(query.Data) {
+		return errors.New("Unable to write payload.")
+	}
+	return err
+}
+
+func (c *ClientCodec) ReadResponseHeader(r *rpc.Response) error {
+	if len(r.Error) > 0 {
+		log.Fatal("client got response error: " + r.Error)
+	}
+	if err := parseHeader(c.Rwc, &r.Seq,
+		&r.ServiceMethod, &c.payloadLen); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (c *ClientCodec) ReadResponseBody(body interface{}) error {
+	buf := make([]byte, c.payloadLen)
+	n, err := c.Rwc.Read(buf)
+	if n != int(c.payloadLen) {
+		return fmt.Errorf("ClientCodec expected: %d. Got: %d\n", c.payloadLen, n)
+	}
+
+	reply := body.(*Reply)
+	reply.Data = buf
+	return err
+}
+
+func (c *ClientCodec) Close() error {
+	return c.Rwc.Close()
+}
diff --git a/conn/codec.go b/conn/codec.go
new file mode 100644
index 00000000..e40864a0
--- /dev/null
+++ b/conn/codec.go
@@ -0,0 +1,60 @@
+package conn
+
+import (
+	"bytes"
+	"encoding/binary"
+	"fmt"
+	"io"
+
+	"github.com/dgraph-io/dgraph/x"
+)
+
+type Query struct {
+	Data []byte
+}
+
+type Reply struct {
+	Data []byte
+}
+
+func writeHeader(rwc io.ReadWriteCloser, seq uint64,
+	method string, data []byte) error {
+
+	var bh bytes.Buffer
+	var rerr error
+
+	x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, seq))
+	x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method))))
+	x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data))))
+	_, err := bh.Write([]byte(method))
+	x.SetError(&rerr, err)
+	if rerr != nil {
+		return rerr
+	}
+	_, err = rwc.Write(bh.Bytes())
+	return err
+}
+
+func parseHeader(rwc io.ReadWriteCloser, seq *uint64,
+	method *string, plen *int32) error {
+
+	var err error
+	var sz int32
+	x.SetError(&err, binary.Read(rwc, binary.LittleEndian, seq))
+	x.SetError(&err, binary.Read(rwc, binary.LittleEndian, &sz))
+	x.SetError(&err, binary.Read(rwc, binary.LittleEndian, plen))
+	if err != nil {
+		return err
+	}
+
+	buf := make([]byte, sz)
+	n, err := rwc.Read(buf)
+	if err != nil {
+		return err
+	}
+	if n != int(sz) {
+		return fmt.Errorf("Expected: %v. Got: %v\n", sz, n)
+	}
+	*method = string(buf)
+	return nil
+}
diff --git a/conn/pool.go b/conn/pool.go
new file mode 100644
index 00000000..7168c9cf
--- /dev/null
+++ b/conn/pool.go
@@ -0,0 +1,64 @@
+package conn
+
+import (
+	"net"
+	"net/rpc"
+
+	"github.com/dgraph-io/dgraph/x"
+)
+
+var glog = x.Log("conn")
+
+type Pool struct {
+	clients chan *rpc.Client
+	addr    string
+}
+
+func NewPool(addr string, maxCap int) *Pool {
+	p := new(Pool)
+	p.addr = addr
+	p.clients = make(chan *rpc.Client, maxCap)
+	client, err := p.dialNew()
+	if err != nil {
+		glog.Fatal(err)
+		return nil
+	}
+	p.clients <- client
+	return p
+}
+
+func (p *Pool) dialNew() (*rpc.Client, error) {
+	nconn, err := net.Dial("tcp", p.addr)
+	if err != nil {
+		return nil, err
+	}
+	cc := &ClientCodec{
+		Rwc: nconn,
+	}
+	return rpc.NewClientWithCodec(cc), nil
+}
+
+func (p *Pool) Get() (*rpc.Client, error) {
+	select {
+	case client := <-p.clients:
+		return client, nil
+	default:
+		return p.dialNew()
+	}
+}
+
+func (p *Pool) Put(client *rpc.Client) error {
+	select {
+	case p.clients <- client:
+		return nil
+	default:
+		return client.Close()
+	}
+}
+
+func (p *Pool) Close() error {
+	// We're not doing a clean exit here. A clean exit here would require
+	// mutex locks around conns; which seems unnecessary just to shut down
+	// the server.
+	return nil
+}
diff --git a/conn/server.go b/conn/server.go
new file mode 100644
index 00000000..986783a5
--- /dev/null
+++ b/conn/server.go
@@ -0,0 +1,61 @@
+package conn
+
+import (
+	"errors"
+	"io"
+	"log"
+	"net/rpc"
+)
+
+type ServerCodec struct {
+	Rwc        io.ReadWriteCloser
+	payloadLen int32
+}
+
+func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error {
+	return parseHeader(c.Rwc, &r.Seq, &r.ServiceMethod, &c.payloadLen)
+}
+
+func (c *ServerCodec) ReadRequestBody(data interface{}) error {
+	b := make([]byte, c.payloadLen)
+	n, err := c.Rwc.Read(b)
+	if err != nil {
+		log.Fatal("server", err)
+	}
+	if n != int(c.payloadLen) {
+		return errors.New("ServerCodec unable to read request.")
+	}
+
+	if data == nil {
+		// If data is nil, discard this request.
+		return nil
+	}
+	query := data.(*Query)
+	query.Data = b
+	return nil
+}
+
+func (c *ServerCodec) WriteResponse(resp *rpc.Response, data interface{}) error {
+	if len(resp.Error) > 0 {
+		log.Fatal("Response has error: " + resp.Error)
+	}
+	if data == nil {
+		log.Fatal("Worker write response data is nil")
+	}
+	reply, ok := data.(*Reply)
+	if !ok {
+		log.Fatal("Unable to convert to reply")
+	}
+
+	if err := writeHeader(c.Rwc, resp.Seq,
+		resp.ServiceMethod, reply.Data); err != nil {
+		return err
+	}
+
+	_, err := c.Rwc.Write(reply.Data)
+	return err
+}
+
+func (c *ServerCodec) Close() error {
+	return c.Rwc.Close()
+}
diff --git a/server/main.go b/server/main.go
index e06c931c..86e2d765 100644
--- a/server/main.go
+++ b/server/main.go
@@ -141,6 +141,7 @@ func main() {
 
 	posting.Init(clog)
 	worker.Init(ps)
+	worker.Connect()
 	uid.Init(ps)
 
 	http.HandleFunc("/query", queryHandler(ps))
diff --git a/task.fbs b/task.fbs
index 39e877cf..2e3acbb0 100644
--- a/task.fbs
+++ b/task.fbs
@@ -9,6 +9,10 @@ table Value {
 	val:[ubyte];
 }
 
+table XidList {
+	xids:[string];
+}
+
 table UidList {
 	uids:[ulong];
 }
diff --git a/task/XidList.go b/task/XidList.go
new file mode 100644
index 00000000..90798a8d
--- /dev/null
+++ b/task/XidList.go
@@ -0,0 +1,38 @@
+// automatically generated, do not modify
+
+package task
+
+import (
+	flatbuffers "github.com/google/flatbuffers/go"
+)
+type XidList struct {
+	_tab flatbuffers.Table
+}
+
+func (rcv *XidList) Init(buf []byte, i flatbuffers.UOffsetT) {
+	rcv._tab.Bytes = buf
+	rcv._tab.Pos = i
+}
+
+func (rcv *XidList) Xids(j int) []byte {
+	o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
+	if o != 0 {
+		a := rcv._tab.Vector(o)
+		return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j * 4))
+	}
+	return nil
+}
+
+func (rcv *XidList) XidsLength() int {
+	o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
+	if o != 0 {
+		return rcv._tab.VectorLen(o)
+	}
+	return 0
+}
+
+func XidListStart(builder *flatbuffers.Builder) { builder.StartObject(1) }
+func XidListAddXids(builder *flatbuffers.Builder, xids flatbuffers.UOffsetT) { builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(xids), 0) }
+func XidListStartXidsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(4, numElems, 4)
+}
+func XidListEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() }
diff --git a/test/client.go b/test/client.go
deleted file mode 100644
index 49fc8773..00000000
--- a/test/client.go
+++ /dev/null
@@ -1,102 +0,0 @@
-package main
-
-import (
-	"bufio"
-	"bytes"
-	"encoding/binary"
-	"errors"
-	"fmt"
-	"io"
-	"log"
-	"net/rpc"
-)
-
-type ccodec struct {
-	rwc        io.ReadWriteCloser
-	ebuf       *bufio.Writer
-	payloadLen int32
-}
-
-func writeHeader(rwc io.ReadWriteCloser, seq uint64,
-	method string, data []byte) error {
-
-	var bh bytes.Buffer
-	var rerr error
-
-	setError(&rerr, binary.Write(&bh, binary.LittleEndian, seq))
-	setError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method))))
-	setError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data))))
-	_, err := bh.Write([]byte(method))
-	setError(&rerr, err)
-	if rerr != nil {
-		return rerr
-	}
-	_, err = rwc.Write(bh.Bytes())
-	return err
-}
-
-func parseHeader(rwc io.ReadWriteCloser, seq *uint64, method *string, plen *int32) error {
-	var err error
-	var sz int32
-	setError(&err, binary.Read(rwc, binary.LittleEndian, seq))
-	setError(&err, binary.Read(rwc, binary.LittleEndian, &sz))
-	setError(&err, binary.Read(rwc, binary.LittleEndian, plen))
-	if err != nil {
-		return err
-	}
-	buf := make([]byte, sz)
-	n, err := rwc.Read(buf)
-	if err != nil {
-		return err
-	}
-	if n != int(sz) {
-		return fmt.Errorf("Expected: %v. Got: %v\n", sz, n)
-	}
-	*method = string(buf)
-	return nil
-}
-
-func (c *ccodec) WriteRequest(r *rpc.Request, body interface{}) error {
-	if body == nil {
-		return errors.New("Nil body")
-	}
-
-	query := body.(*Query)
-	if err := writeHeader(c.rwc, r.Seq, r.ServiceMethod, query.d); err != nil {
-		return err
-	}
-
-	n, err := c.rwc.Write(query.d)
-	if n != len(query.d) {
-		return errors.New("Unable to write payload.")
-	}
-	return err
-}
-
-func (c *ccodec) ReadResponseHeader(r *rpc.Response) error {
-	if len(r.Error) > 0 {
-		log.Fatal("client got response error: " + r.Error)
-	}
-	if err := parseHeader(c.rwc, &r.Seq,
-		&r.ServiceMethod, &c.payloadLen); err != nil {
-		return err
-	}
-	fmt.Println("Client got response:", r.Seq)
-	fmt.Println("Client got response:", r.ServiceMethod)
-	return nil
-}
-
-func (c *ccodec) ReadResponseBody(body interface{}) error {
-	buf := make([]byte, c.payloadLen)
-	n, err := c.rwc.Read(buf)
-	if n != int(c.payloadLen) {
-		return fmt.Errorf("Client expected: %d. Got: %d\n", c.payloadLen, n)
-	}
-	reply := body.(*Reply)
-	reply.d = buf
-	return err
-}
-
-func (c *ccodec) Close() error {
-	return c.rwc.Close()
-}
diff --git a/test/main.go b/test/main.go
deleted file mode 100644
index 31dda1b8..00000000
--- a/test/main.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package main
-
-import (
-	"bufio"
-	"fmt"
-	"io"
-	"log"
-	"math/rand"
-	"net"
-	"net/rpc"
-)
-
-type Query struct {
-	d []byte
-}
-
-type Reply struct {
-	d []byte
-}
-
-func setError(prev *error, n error) {
-	if prev == nil {
-		prev = &n
-	}
-}
-
-type Worker struct {
-}
-
-func serveIt(conn io.ReadWriteCloser) {
-	for {
-		srv := &scodec{
-			rwc:  conn,
-			ebuf: bufio.NewWriter(conn),
-		}
-		rpc.ServeRequest(srv)
-	}
-}
-
-func (w *Worker) Receive(query *Query, reply *Reply) error {
-	fmt.Printf("Worker received: [%s]\n", string(query.d))
-	reply.d = []byte("abcdefghij-Hello World!")
-	return nil
-}
-
-func runServer(address string) error {
-	w := new(Worker)
-	if err := rpc.Register(w); err != nil {
-		return err
-	}
-
-	ln, err := net.Listen("tcp", address)
-	if err != nil {
-		fmt.Printf("listen(%q): %s\n", address, err)
-		return err
-	}
-	fmt.Printf("Worker listening on %s\n", ln.Addr())
-	go func() {
-		for {
-			cxn, err := ln.Accept()
-			if err != nil {
-				log.Fatalf("listen(%q): %s\n", address, err)
-				return
-			}
-			log.Printf("Worker accepted connection to %s from %s\n",
-				cxn.LocalAddr(), cxn.RemoteAddr())
-			go serveIt(cxn)
-		}
-	}()
-	return nil
-}
-
-func main() {
-	addresses := map[int]string{
-		1: "127.0.0.1:10000",
-		// 2: "127.0.0.1:10001",
-		// 3: "127.0.0.1:10002",
-	}
-
-	for _, address := range addresses {
-		runServer(address)
-	}
-
-	clients := make(map[int]*rpc.Client)
-	for id, address := range addresses {
-		conn, err := net.Dial("tcp", address)
-		if err != nil {
-			log.Fatal("dial", err)
-		}
-		cc := &ccodec{
-			rwc:  conn,
-			ebuf: bufio.NewWriter(conn),
-		}
-		clients[id] = rpc.NewClientWithCodec(cc)
-	}
-
-	for i := 0; i < 1; i++ {
-		client := clients[1]
-		if client == nil {
-			log.Fatal("Worker is nil")
-		}
-
-		id := 0
-		// for id, server := range servers {
-		query := new(Query)
-		query.d = []byte(fmt.Sprintf("id:%d Rand: %d", id, rand.Int()))
-		reply := new(Reply)
-		if err := client.Call("Worker.Receive", query, reply); err != nil {
-			log.Fatal("call", err)
-		}
-
-		fmt.Printf("Returned: %s\n", string(reply.d))
-		// }
-	}
-}
diff --git a/test/server.go b/test/server.go
deleted file mode 100644
index b4ff9255..00000000
--- a/test/server.go
+++ /dev/null
@@ -1,79 +0,0 @@
-package main
-
-import (
-	"bufio"
-	"errors"
-	"fmt"
-	"io"
-	"log"
-	"net/rpc"
-	"reflect"
-)
-
-type scodec struct {
-	rwc        io.ReadWriteCloser
-	ebuf       *bufio.Writer
-	payloadLen int32
-}
-
-func (c *scodec) ReadRequestHeader(r *rpc.Request) error {
-	var err error
-	if err = parseHeader(c.rwc, &r.Seq,
-		&r.ServiceMethod, &c.payloadLen); err != nil {
-		return err
-	}
-
-	fmt.Println("server using custom codec to read header")
-	fmt.Println("server method called:", r.ServiceMethod)
-	fmt.Println("server method called:", r.Seq)
-	return nil
-}
-
-func (c *scodec) ReadRequestBody(data interface{}) error {
-	if data == nil {
-		log.Fatal("Why is data nil here?")
-	}
-	value := reflect.ValueOf(data)
-	if value.Type().Kind() != reflect.Ptr {
-		log.Fatal("Should of of type pointer")
-	}
-
-	b := make([]byte, c.payloadLen)
-	n, err := c.rwc.Read(b)
-	fmt.Printf("Worker read n bytes: %v %s\n", n, string(b))
-	if err != nil {
-		log.Fatal("server", err)
-	}
-	if n != int(c.payloadLen) {
-		return errors.New("Server unable to read request.")
-	}
-
-	query := data.(*Query)
-	query.d = b
-	return nil
-}
-
-func (c *scodec) WriteResponse(resp *rpc.Response, data interface{}) error {
-	if len(resp.Error) > 0 {
-		log.Fatal("Response has error: " + resp.Error)
-	}
-	if data == nil {
-		log.Fatal("Worker write response data is nil")
-	}
-	reply, ok := data.(*Reply)
-	if !ok {
-		log.Fatal("Unable to convert to reply")
-	}
-
-	if err := writeHeader(c.rwc, resp.Seq,
-		resp.ServiceMethod, reply.d); err != nil {
-		return err
-	}
-
-	_, err := c.rwc.Write(reply.d)
-	return err
-}
-
-func (c *scodec) Close() error {
-	return c.rwc.Close()
-}
diff --git a/worker/worker.go b/worker/worker.go
index 5538e0fa..5966a632 100644
--- a/worker/worker.go
+++ b/worker/worker.go
@@ -1,6 +1,13 @@
 package worker
 
 import (
+	"flag"
+	"io"
+	"net"
+	"net/rpc"
+	"strings"
+
+	"github.com/dgraph-io/dgraph/conn"
 	"github.com/dgraph-io/dgraph/posting"
 	"github.com/dgraph-io/dgraph/store"
 	"github.com/dgraph-io/dgraph/task"
@@ -8,12 +15,53 @@ import (
 	"github.com/google/flatbuffers/go"
 )
 
+var workers = flag.String("workers", "",
+	"Comma separated list of IP addresses of workers")
+var workerPort = flag.String("workerport", ":12345",
+	"Port used by worker for internal communication.")
+
+var glog = x.Log("worker")
 var dataStore *store.Store
+var pools []*conn.Pool
 
 func Init(ps *store.Store) {
 	dataStore = ps
 }
 
+func Connect() {
+	w := new(Worker)
+	if err := rpc.Register(w); err != nil {
+		glog.Fatal(err)
+	}
+	if err := runServer(*workerPort); err != nil {
+		glog.Fatal(err)
+	}
+
+	addrs := strings.Split(*workers, ",")
+	var pools []*conn.Pool
+	for _, addr := range addrs {
+		if len(addr) == 0 {
+			continue
+		}
+		pool := conn.NewPool(addr, 5)
+		client, err := pool.Get()
+		if err != nil {
+			glog.Fatal(err)
+		}
+		query := new(conn.Query)
+		query.Data = []byte("hello")
+		reply := new(conn.Reply)
+		if err = client.Call("Worker.Hello", query, reply); err != nil {
+			glog.WithField("call", "Worker.Hello").Fatal(err)
+		}
+		glog.WithField("reply", string(reply.Data)).WithField("addr", addr).
+			Info("Got reply from server")
+		pools = append(pools, pool)
+	}
+
+	glog.Info("Server started. Clients connected.")
+}
+
 func ProcessTask(query []byte) (result []byte, rerr error) {
 	uo := flatbuffers.GetUOffsetT(query)
 	q := new(task.Query)
@@ -78,3 +126,48 @@ func NewQuery(attr string, uids []uint64) []byte {
 	b.Finish(qend)
 	return b.Bytes[b.Head():]
 }
+
+type Worker struct {
+}
+
+func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error {
+	if string(query.Data) == "hello" {
+		reply.Data = []byte("Oh hello there!")
+	} else {
+		reply.Data = []byte("Hey stranger!")
+	}
+	return nil
+}
+
+func serveRequests(irwc io.ReadWriteCloser) {
+	for {
+		sc := &conn.ServerCodec{
+			Rwc: irwc,
+		}
+		rpc.ServeRequest(sc)
+	}
+}
+
+func runServer(address string) error {
+	ln, err := net.Listen("tcp", address)
+	if err != nil {
+		glog.Fatalf("While running server: %v", err)
+		return err
+	}
+	glog.WithField("address", ln.Addr()).Info("Worker listening")
+
+	go func() {
+		for {
+			cxn, err := ln.Accept()
+			if err != nil {
+				glog.Fatalf("listen(%q): %s\n", address, err)
+				return
+			}
+			glog.WithField("local", cxn.LocalAddr()).
+				WithField("remote", cxn.RemoteAddr()).
+				Debug("Worker accepted connection")
+			go serveRequests(cxn)
+		}
+	}()
+	return nil
+}
diff --git a/x/x.go b/x/x.go
index a9e29e48..2ae25556 100644
--- a/x/x.go
+++ b/x/x.go
@@ -39,6 +39,12 @@ type DirectedEdge struct {
 	Timestamp time.Time
 }
 
+func SetError(prev *error, n error) {
+	if prev == nil {
+		prev = &n
+	}
+}
+
 func Log(p string) *logrus.Entry {
 	l := logrus.WithFields(logrus.Fields{
 		"package": p,
-- 
GitLab