From 0b312744773ab7c32dcf9195b43c5eb990c9ca7c Mon Sep 17 00:00:00 2001 From: Manish R Jain <manishrjain@gmail.com> Date: Wed, 24 Feb 2016 18:17:42 +1100 Subject: [PATCH] Custom encoders for communication between workers via net/rpc over TCP. Also, hello there! --- conn/client.go | 57 +++++++++++++++++++++++ conn/codec.go | 60 +++++++++++++++++++++++++ conn/pool.go | 64 ++++++++++++++++++++++++++ conn/server.go | 61 +++++++++++++++++++++++++ server/main.go | 1 + task.fbs | 4 ++ task/XidList.go | 38 ++++++++++++++++ test/client.go | 102 ----------------------------------------- test/main.go | 115 ----------------------------------------------- test/server.go | 79 -------------------------------- worker/worker.go | 93 ++++++++++++++++++++++++++++++++++++++ x/x.go | 6 +++ 12 files changed, 384 insertions(+), 296 deletions(-) create mode 100644 conn/client.go create mode 100644 conn/codec.go create mode 100644 conn/pool.go create mode 100644 conn/server.go create mode 100644 task/XidList.go delete mode 100644 test/client.go delete mode 100644 test/main.go delete mode 100644 test/server.go diff --git a/conn/client.go b/conn/client.go new file mode 100644 index 00000000..2b668dc0 --- /dev/null +++ b/conn/client.go @@ -0,0 +1,57 @@ +package conn + +import ( + "errors" + "fmt" + "io" + "log" + "net/rpc" +) + +type ClientCodec struct { + Rwc io.ReadWriteCloser + payloadLen int32 +} + +func (c *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error { + if body == nil { + return fmt.Errorf("Nil request body from client.") + } + + query := body.(*Query) + if err := writeHeader(c.Rwc, r.Seq, r.ServiceMethod, query.Data); err != nil { + return err + } + n, err := c.Rwc.Write(query.Data) + if n != len(query.Data) { + return errors.New("Unable to write payload.") + } + return err +} + +func (c *ClientCodec) ReadResponseHeader(r *rpc.Response) error { + if len(r.Error) > 0 { + log.Fatal("client got response error: " + r.Error) + } + if err := parseHeader(c.Rwc, &r.Seq, + &r.ServiceMethod, &c.payloadLen); err != nil { + return err + } + return nil +} + +func (c *ClientCodec) ReadResponseBody(body interface{}) error { + buf := make([]byte, c.payloadLen) + n, err := c.Rwc.Read(buf) + if n != int(c.payloadLen) { + return fmt.Errorf("ClientCodec expected: %d. Got: %d\n", c.payloadLen, n) + } + + reply := body.(*Reply) + reply.Data = buf + return err +} + +func (c *ClientCodec) Close() error { + return c.Rwc.Close() +} diff --git a/conn/codec.go b/conn/codec.go new file mode 100644 index 00000000..e40864a0 --- /dev/null +++ b/conn/codec.go @@ -0,0 +1,60 @@ +package conn + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/dgraph-io/dgraph/x" +) + +type Query struct { + Data []byte +} + +type Reply struct { + Data []byte +} + +func writeHeader(rwc io.ReadWriteCloser, seq uint64, + method string, data []byte) error { + + var bh bytes.Buffer + var rerr error + + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, seq)) + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method)))) + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data)))) + _, err := bh.Write([]byte(method)) + x.SetError(&rerr, err) + if rerr != nil { + return rerr + } + _, err = rwc.Write(bh.Bytes()) + return err +} + +func parseHeader(rwc io.ReadWriteCloser, seq *uint64, + method *string, plen *int32) error { + + var err error + var sz int32 + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, seq)) + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, &sz)) + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, plen)) + if err != nil { + return err + } + + buf := make([]byte, sz) + n, err := rwc.Read(buf) + if err != nil { + return err + } + if n != int(sz) { + return fmt.Errorf("Expected: %v. Got: %v\n", sz, n) + } + *method = string(buf) + return nil +} diff --git a/conn/pool.go b/conn/pool.go new file mode 100644 index 00000000..7168c9cf --- /dev/null +++ b/conn/pool.go @@ -0,0 +1,64 @@ +package conn + +import ( + "net" + "net/rpc" + + "github.com/dgraph-io/dgraph/x" +) + +var glog = x.Log("conn") + +type Pool struct { + clients chan *rpc.Client + addr string +} + +func NewPool(addr string, maxCap int) *Pool { + p := new(Pool) + p.addr = addr + p.clients = make(chan *rpc.Client, maxCap) + client, err := p.dialNew() + if err != nil { + glog.Fatal(err) + return nil + } + p.clients <- client + return p +} + +func (p *Pool) dialNew() (*rpc.Client, error) { + nconn, err := net.Dial("tcp", p.addr) + if err != nil { + return nil, err + } + cc := &ClientCodec{ + Rwc: nconn, + } + return rpc.NewClientWithCodec(cc), nil +} + +func (p *Pool) Get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + +func (p *Pool) Put(client *rpc.Client) error { + select { + case p.clients <- client: + return nil + default: + return client.Close() + } +} + +func (p *Pool) Close() error { + // We're not doing a clean exit here. A clean exit here would require + // mutex locks around conns; which seems unnecessary just to shut down + // the server. + return nil +} diff --git a/conn/server.go b/conn/server.go new file mode 100644 index 00000000..986783a5 --- /dev/null +++ b/conn/server.go @@ -0,0 +1,61 @@ +package conn + +import ( + "errors" + "io" + "log" + "net/rpc" +) + +type ServerCodec struct { + Rwc io.ReadWriteCloser + payloadLen int32 +} + +func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error { + return parseHeader(c.Rwc, &r.Seq, &r.ServiceMethod, &c.payloadLen) +} + +func (c *ServerCodec) ReadRequestBody(data interface{}) error { + b := make([]byte, c.payloadLen) + n, err := c.Rwc.Read(b) + if err != nil { + log.Fatal("server", err) + } + if n != int(c.payloadLen) { + return errors.New("ServerCodec unable to read request.") + } + + if data == nil { + // If data is nil, discard this request. + return nil + } + query := data.(*Query) + query.Data = b + return nil +} + +func (c *ServerCodec) WriteResponse(resp *rpc.Response, data interface{}) error { + if len(resp.Error) > 0 { + log.Fatal("Response has error: " + resp.Error) + } + if data == nil { + log.Fatal("Worker write response data is nil") + } + reply, ok := data.(*Reply) + if !ok { + log.Fatal("Unable to convert to reply") + } + + if err := writeHeader(c.Rwc, resp.Seq, + resp.ServiceMethod, reply.Data); err != nil { + return err + } + + _, err := c.Rwc.Write(reply.Data) + return err +} + +func (c *ServerCodec) Close() error { + return c.Rwc.Close() +} diff --git a/server/main.go b/server/main.go index e06c931c..86e2d765 100644 --- a/server/main.go +++ b/server/main.go @@ -141,6 +141,7 @@ func main() { posting.Init(clog) worker.Init(ps) + worker.Connect() uid.Init(ps) http.HandleFunc("/query", queryHandler(ps)) diff --git a/task.fbs b/task.fbs index 39e877cf..2e3acbb0 100644 --- a/task.fbs +++ b/task.fbs @@ -9,6 +9,10 @@ table Value { val:[ubyte]; } +table XidList { + xids:[string]; +} + table UidList { uids:[ulong]; } diff --git a/task/XidList.go b/task/XidList.go new file mode 100644 index 00000000..90798a8d --- /dev/null +++ b/task/XidList.go @@ -0,0 +1,38 @@ +// automatically generated, do not modify + +package task + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) +type XidList struct { + _tab flatbuffers.Table +} + +func (rcv *XidList) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *XidList) Xids(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j * 4)) + } + return nil +} + +func (rcv *XidList) XidsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func XidListStart(builder *flatbuffers.Builder) { builder.StartObject(1) } +func XidListAddXids(builder *flatbuffers.Builder, xids flatbuffers.UOffsetT) { builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(xids), 0) } +func XidListStartXidsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(4, numElems, 4) +} +func XidListEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/test/client.go b/test/client.go deleted file mode 100644 index 49fc8773..00000000 --- a/test/client.go +++ /dev/null @@ -1,102 +0,0 @@ -package main - -import ( - "bufio" - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - "log" - "net/rpc" -) - -type ccodec struct { - rwc io.ReadWriteCloser - ebuf *bufio.Writer - payloadLen int32 -} - -func writeHeader(rwc io.ReadWriteCloser, seq uint64, - method string, data []byte) error { - - var bh bytes.Buffer - var rerr error - - setError(&rerr, binary.Write(&bh, binary.LittleEndian, seq)) - setError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method)))) - setError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data)))) - _, err := bh.Write([]byte(method)) - setError(&rerr, err) - if rerr != nil { - return rerr - } - _, err = rwc.Write(bh.Bytes()) - return err -} - -func parseHeader(rwc io.ReadWriteCloser, seq *uint64, method *string, plen *int32) error { - var err error - var sz int32 - setError(&err, binary.Read(rwc, binary.LittleEndian, seq)) - setError(&err, binary.Read(rwc, binary.LittleEndian, &sz)) - setError(&err, binary.Read(rwc, binary.LittleEndian, plen)) - if err != nil { - return err - } - buf := make([]byte, sz) - n, err := rwc.Read(buf) - if err != nil { - return err - } - if n != int(sz) { - return fmt.Errorf("Expected: %v. Got: %v\n", sz, n) - } - *method = string(buf) - return nil -} - -func (c *ccodec) WriteRequest(r *rpc.Request, body interface{}) error { - if body == nil { - return errors.New("Nil body") - } - - query := body.(*Query) - if err := writeHeader(c.rwc, r.Seq, r.ServiceMethod, query.d); err != nil { - return err - } - - n, err := c.rwc.Write(query.d) - if n != len(query.d) { - return errors.New("Unable to write payload.") - } - return err -} - -func (c *ccodec) ReadResponseHeader(r *rpc.Response) error { - if len(r.Error) > 0 { - log.Fatal("client got response error: " + r.Error) - } - if err := parseHeader(c.rwc, &r.Seq, - &r.ServiceMethod, &c.payloadLen); err != nil { - return err - } - fmt.Println("Client got response:", r.Seq) - fmt.Println("Client got response:", r.ServiceMethod) - return nil -} - -func (c *ccodec) ReadResponseBody(body interface{}) error { - buf := make([]byte, c.payloadLen) - n, err := c.rwc.Read(buf) - if n != int(c.payloadLen) { - return fmt.Errorf("Client expected: %d. Got: %d\n", c.payloadLen, n) - } - reply := body.(*Reply) - reply.d = buf - return err -} - -func (c *ccodec) Close() error { - return c.rwc.Close() -} diff --git a/test/main.go b/test/main.go deleted file mode 100644 index 31dda1b8..00000000 --- a/test/main.go +++ /dev/null @@ -1,115 +0,0 @@ -package main - -import ( - "bufio" - "fmt" - "io" - "log" - "math/rand" - "net" - "net/rpc" -) - -type Query struct { - d []byte -} - -type Reply struct { - d []byte -} - -func setError(prev *error, n error) { - if prev == nil { - prev = &n - } -} - -type Worker struct { -} - -func serveIt(conn io.ReadWriteCloser) { - for { - srv := &scodec{ - rwc: conn, - ebuf: bufio.NewWriter(conn), - } - rpc.ServeRequest(srv) - } -} - -func (w *Worker) Receive(query *Query, reply *Reply) error { - fmt.Printf("Worker received: [%s]\n", string(query.d)) - reply.d = []byte("abcdefghij-Hello World!") - return nil -} - -func runServer(address string) error { - w := new(Worker) - if err := rpc.Register(w); err != nil { - return err - } - - ln, err := net.Listen("tcp", address) - if err != nil { - fmt.Printf("listen(%q): %s\n", address, err) - return err - } - fmt.Printf("Worker listening on %s\n", ln.Addr()) - go func() { - for { - cxn, err := ln.Accept() - if err != nil { - log.Fatalf("listen(%q): %s\n", address, err) - return - } - log.Printf("Worker accepted connection to %s from %s\n", - cxn.LocalAddr(), cxn.RemoteAddr()) - go serveIt(cxn) - } - }() - return nil -} - -func main() { - addresses := map[int]string{ - 1: "127.0.0.1:10000", - // 2: "127.0.0.1:10001", - // 3: "127.0.0.1:10002", - } - - for _, address := range addresses { - runServer(address) - } - - clients := make(map[int]*rpc.Client) - for id, address := range addresses { - conn, err := net.Dial("tcp", address) - if err != nil { - log.Fatal("dial", err) - } - cc := &ccodec{ - rwc: conn, - ebuf: bufio.NewWriter(conn), - } - clients[id] = rpc.NewClientWithCodec(cc) - } - - for i := 0; i < 1; i++ { - client := clients[1] - if client == nil { - log.Fatal("Worker is nil") - } - - id := 0 - // for id, server := range servers { - query := new(Query) - query.d = []byte(fmt.Sprintf("id:%d Rand: %d", id, rand.Int())) - reply := new(Reply) - if err := client.Call("Worker.Receive", query, reply); err != nil { - log.Fatal("call", err) - } - - fmt.Printf("Returned: %s\n", string(reply.d)) - // } - } -} diff --git a/test/server.go b/test/server.go deleted file mode 100644 index b4ff9255..00000000 --- a/test/server.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "bufio" - "errors" - "fmt" - "io" - "log" - "net/rpc" - "reflect" -) - -type scodec struct { - rwc io.ReadWriteCloser - ebuf *bufio.Writer - payloadLen int32 -} - -func (c *scodec) ReadRequestHeader(r *rpc.Request) error { - var err error - if err = parseHeader(c.rwc, &r.Seq, - &r.ServiceMethod, &c.payloadLen); err != nil { - return err - } - - fmt.Println("server using custom codec to read header") - fmt.Println("server method called:", r.ServiceMethod) - fmt.Println("server method called:", r.Seq) - return nil -} - -func (c *scodec) ReadRequestBody(data interface{}) error { - if data == nil { - log.Fatal("Why is data nil here?") - } - value := reflect.ValueOf(data) - if value.Type().Kind() != reflect.Ptr { - log.Fatal("Should of of type pointer") - } - - b := make([]byte, c.payloadLen) - n, err := c.rwc.Read(b) - fmt.Printf("Worker read n bytes: %v %s\n", n, string(b)) - if err != nil { - log.Fatal("server", err) - } - if n != int(c.payloadLen) { - return errors.New("Server unable to read request.") - } - - query := data.(*Query) - query.d = b - return nil -} - -func (c *scodec) WriteResponse(resp *rpc.Response, data interface{}) error { - if len(resp.Error) > 0 { - log.Fatal("Response has error: " + resp.Error) - } - if data == nil { - log.Fatal("Worker write response data is nil") - } - reply, ok := data.(*Reply) - if !ok { - log.Fatal("Unable to convert to reply") - } - - if err := writeHeader(c.rwc, resp.Seq, - resp.ServiceMethod, reply.d); err != nil { - return err - } - - _, err := c.rwc.Write(reply.d) - return err -} - -func (c *scodec) Close() error { - return c.rwc.Close() -} diff --git a/worker/worker.go b/worker/worker.go index 5538e0fa..5966a632 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -1,6 +1,13 @@ package worker import ( + "flag" + "io" + "net" + "net/rpc" + "strings" + + "github.com/dgraph-io/dgraph/conn" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/store" "github.com/dgraph-io/dgraph/task" @@ -8,12 +15,53 @@ import ( "github.com/google/flatbuffers/go" ) +var workers = flag.String("workers", "", + "Comma separated list of IP addresses of workers") +var workerPort = flag.String("workerport", ":12345", + "Port used by worker for internal communication.") + +var glog = x.Log("worker") var dataStore *store.Store +var pools []*conn.Pool func Init(ps *store.Store) { dataStore = ps } +func Connect() { + w := new(Worker) + if err := rpc.Register(w); err != nil { + glog.Fatal(err) + } + if err := runServer(*workerPort); err != nil { + glog.Fatal(err) + } + + addrs := strings.Split(*workers, ",") + var pools []*conn.Pool + for _, addr := range addrs { + if len(addr) == 0 { + continue + } + pool := conn.NewPool(addr, 5) + client, err := pool.Get() + if err != nil { + glog.Fatal(err) + } + query := new(conn.Query) + query.Data = []byte("hello") + reply := new(conn.Reply) + if err = client.Call("Worker.Hello", query, reply); err != nil { + glog.WithField("call", "Worker.Hello").Fatal(err) + } + glog.WithField("reply", string(reply.Data)).WithField("addr", addr). + Info("Got reply from server") + pools = append(pools, pool) + } + + glog.Info("Server started. Clients connected.") +} + func ProcessTask(query []byte) (result []byte, rerr error) { uo := flatbuffers.GetUOffsetT(query) q := new(task.Query) @@ -78,3 +126,48 @@ func NewQuery(attr string, uids []uint64) []byte { b.Finish(qend) return b.Bytes[b.Head():] } + +type Worker struct { +} + +func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { + if string(query.Data) == "hello" { + reply.Data = []byte("Oh hello there!") + } else { + reply.Data = []byte("Hey stranger!") + } + return nil +} + +func serveRequests(irwc io.ReadWriteCloser) { + for { + sc := &conn.ServerCodec{ + Rwc: irwc, + } + rpc.ServeRequest(sc) + } +} + +func runServer(address string) error { + ln, err := net.Listen("tcp", address) + if err != nil { + glog.Fatalf("While running server: %v", err) + return err + } + glog.WithField("address", ln.Addr()).Info("Worker listening") + + go func() { + for { + cxn, err := ln.Accept() + if err != nil { + glog.Fatalf("listen(%q): %s\n", address, err) + return + } + glog.WithField("local", cxn.LocalAddr()). + WithField("remote", cxn.RemoteAddr()). + Debug("Worker accepted connection") + go serveRequests(cxn) + } + }() + return nil +} diff --git a/x/x.go b/x/x.go index a9e29e48..2ae25556 100644 --- a/x/x.go +++ b/x/x.go @@ -39,6 +39,12 @@ type DirectedEdge struct { Timestamp time.Time } +func SetError(prev *error, n error) { + if prev == nil { + prev = &n + } +} + func Log(p string) *logrus.Entry { l := logrus.WithFields(logrus.Fields{ "package": p, -- GitLab