diff --git a/conn/codec_test.go b/conn/codec_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ce3fe0e7b2dae2799f004bcb06338d93a951ae4b --- /dev/null +++ b/conn/codec_test.go @@ -0,0 +1,145 @@ +package conn + +import ( + "bytes" + "net/rpc" + "testing" +) + +type buf struct { + data chan byte +} + +func newBuf() *buf { + b := new(buf) + b.data = make(chan byte, 10000) + return b +} + +func (b *buf) Read(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + p[i] = <-b.data + } + return len(p), nil +} + +func (b *buf) Write(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + b.data <- p[i] + } + return len(p), nil +} + +func (b *buf) Close() error { + close(b.data) + return nil +} + +func TestWriteAndParseHeader(t *testing.T) { + b := newBuf() + data := []byte("oh hey") + if err := writeHeader(b, 11, "testing.T", data); err != nil { + t.Error(err) + t.Fail() + } + var seq uint64 + var method string + var plen int32 + if err := parseHeader(b, &seq, &method, &plen); err != nil { + t.Error(err) + t.Fail() + } + if seq != 11 { + t.Error("Sequence number. Expected 11. Got: %v", seq) + } + if method != "testing.T" { + t.Error("Method name. Expected: testing.T. Got: %v", method) + } + if plen != int32(len(data)) { + t.Error("Payload length. Expected: %v. Got: %v", len(data), plen) + } +} + +func TestClientToServer(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Request{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + query := new(Query) + query.Data = []byte("iamaquery") + if err := cc.WriteRequest(r, query); err != nil { + t.Error(err) + } + + sr := new(rpc.Request) + if err := sc.ReadRequestHeader(sr); err != nil { + t.Error(err) + } + if sr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) + } + if sr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, sr.ServiceMethod) + } + + squery := new(Query) + if err := sc.ReadRequestBody(squery); err != nil { + t.Error(err) + } + if !bytes.Equal(squery.Data, query.Data) { + t.Error("Queries don't match. Expected: %v Got: %v", + string(query.Data), string(squery.Data)) + } +} + +func TestServerToClient(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Response{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + reply := new(Reply) + reply.Data = []byte("iamareply") + if err := sc.WriteResponse(r, reply); err != nil { + t.Error(err) + } + + cr := new(rpc.Response) + if err := cc.ReadResponseHeader(cr); err != nil { + t.Error(err) + } + if cr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) + } + if cr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, cr.ServiceMethod) + } + + creply := new(Reply) + if err := cc.ReadResponseBody(creply); err != nil { + t.Error(err) + } + if !bytes.Equal(creply.Data, reply.Data) { + t.Error("Replies don't match. Expected: %v Got: %v", + string(reply.Data), string(creply.Data)) + } +} diff --git a/conn/pool.go b/conn/pool.go index 7168c9cf3eb1f96dc89feea11a7c1c96d3cdbc3b..dd2c8d64986279416b24de26661c6f43c5f656ad 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -38,16 +38,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { return rpc.NewClientWithCodec(cc), nil } -func (p *Pool) Get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() +func (p *Pool) Call(serviceMethod string, args interface{}, + reply interface{}) error { + + client, err := p.get() + if err != nil { + return err + } + if err = client.Call(serviceMethod, args, reply); err != nil { + return err } -} -func (p *Pool) Put(client *rpc.Client) error { select { case p.clients <- client: return nil @@ -56,9 +57,18 @@ func (p *Pool) Put(client *rpc.Client) error { } } +func (p *Pool) get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + func (p *Pool) Close() error { // We're not doing a clean exit here. A clean exit here would require - // mutex locks around conns; which seems unnecessary just to shut down - // the server. + // synchronization, which seems unnecessary for now. But, we should + // add one if required later. return nil } diff --git a/worker/worker.go b/worker/worker.go index 5966a632afe1286ea589ac28d7e7462a59525396..959d4161f851c59a679c0be4c01cb67b8bb1961b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -44,14 +44,10 @@ func Connect() { continue } pool := conn.NewPool(addr, 5) - client, err := pool.Get() - if err != nil { - glog.Fatal(err) - } query := new(conn.Query) query.Data = []byte("hello") reply := new(conn.Reply) - if err = client.Call("Worker.Hello", query, reply); err != nil { + if err := pool.Call("Worker.Hello", query, reply); err != nil { glog.WithField("call", "Worker.Hello").Fatal(err) } glog.WithField("reply", string(reply.Data)).WithField("addr", addr).