From c27093d7f9e2cdb7b2f634f47b92cd4ee152b412 Mon Sep 17 00:00:00 2001 From: Manish R Jain <manishrjain@gmail.com> Date: Thu, 25 Feb 2016 17:14:18 +1100 Subject: [PATCH] Encapsulate Call within the client pool library, so we don't have to expose the clients. Also, write tests for conn module. --- conn/codec_test.go | 145 +++++++++++++++++++++++++++++++++++++++++++++ conn/pool.go | 30 ++++++---- worker/worker.go | 6 +- 3 files changed, 166 insertions(+), 15 deletions(-) create mode 100644 conn/codec_test.go diff --git a/conn/codec_test.go b/conn/codec_test.go new file mode 100644 index 00000000..ce3fe0e7 --- /dev/null +++ b/conn/codec_test.go @@ -0,0 +1,145 @@ +package conn + +import ( + "bytes" + "net/rpc" + "testing" +) + +type buf struct { + data chan byte +} + +func newBuf() *buf { + b := new(buf) + b.data = make(chan byte, 10000) + return b +} + +func (b *buf) Read(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + p[i] = <-b.data + } + return len(p), nil +} + +func (b *buf) Write(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + b.data <- p[i] + } + return len(p), nil +} + +func (b *buf) Close() error { + close(b.data) + return nil +} + +func TestWriteAndParseHeader(t *testing.T) { + b := newBuf() + data := []byte("oh hey") + if err := writeHeader(b, 11, "testing.T", data); err != nil { + t.Error(err) + t.Fail() + } + var seq uint64 + var method string + var plen int32 + if err := parseHeader(b, &seq, &method, &plen); err != nil { + t.Error(err) + t.Fail() + } + if seq != 11 { + t.Error("Sequence number. Expected 11. Got: %v", seq) + } + if method != "testing.T" { + t.Error("Method name. Expected: testing.T. Got: %v", method) + } + if plen != int32(len(data)) { + t.Error("Payload length. Expected: %v. Got: %v", len(data), plen) + } +} + +func TestClientToServer(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Request{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + query := new(Query) + query.Data = []byte("iamaquery") + if err := cc.WriteRequest(r, query); err != nil { + t.Error(err) + } + + sr := new(rpc.Request) + if err := sc.ReadRequestHeader(sr); err != nil { + t.Error(err) + } + if sr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) + } + if sr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, sr.ServiceMethod) + } + + squery := new(Query) + if err := sc.ReadRequestBody(squery); err != nil { + t.Error(err) + } + if !bytes.Equal(squery.Data, query.Data) { + t.Error("Queries don't match. Expected: %v Got: %v", + string(query.Data), string(squery.Data)) + } +} + +func TestServerToClient(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Response{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + reply := new(Reply) + reply.Data = []byte("iamareply") + if err := sc.WriteResponse(r, reply); err != nil { + t.Error(err) + } + + cr := new(rpc.Response) + if err := cc.ReadResponseHeader(cr); err != nil { + t.Error(err) + } + if cr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) + } + if cr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, cr.ServiceMethod) + } + + creply := new(Reply) + if err := cc.ReadResponseBody(creply); err != nil { + t.Error(err) + } + if !bytes.Equal(creply.Data, reply.Data) { + t.Error("Replies don't match. Expected: %v Got: %v", + string(reply.Data), string(creply.Data)) + } +} diff --git a/conn/pool.go b/conn/pool.go index 7168c9cf..dd2c8d64 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -38,16 +38,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { return rpc.NewClientWithCodec(cc), nil } -func (p *Pool) Get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() +func (p *Pool) Call(serviceMethod string, args interface{}, + reply interface{}) error { + + client, err := p.get() + if err != nil { + return err + } + if err = client.Call(serviceMethod, args, reply); err != nil { + return err } -} -func (p *Pool) Put(client *rpc.Client) error { select { case p.clients <- client: return nil @@ -56,9 +57,18 @@ func (p *Pool) Put(client *rpc.Client) error { } } +func (p *Pool) get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + func (p *Pool) Close() error { // We're not doing a clean exit here. A clean exit here would require - // mutex locks around conns; which seems unnecessary just to shut down - // the server. + // synchronization, which seems unnecessary for now. But, we should + // add one if required later. return nil } diff --git a/worker/worker.go b/worker/worker.go index 5966a632..959d4161 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -44,14 +44,10 @@ func Connect() { continue } pool := conn.NewPool(addr, 5) - client, err := pool.Get() - if err != nil { - glog.Fatal(err) - } query := new(conn.Query) query.Data = []byte("hello") reply := new(conn.Reply) - if err = client.Call("Worker.Hello", query, reply); err != nil { + if err := pool.Call("Worker.Hello", query, reply); err != nil { glog.WithField("call", "Worker.Hello").Fatal(err) } glog.WithField("reply", string(reply.Data)).WithField("addr", addr). -- GitLab