Skip to content
Snippets Groups Projects
Commit c27093d7 authored by Manish R Jain's avatar Manish R Jain
Browse files

Encapsulate Call within the client pool library, so we don't have to expose...

Encapsulate Call within the client pool library, so we don't have to expose the clients. Also, write tests for conn module.
parent 2a5f0529
No related branches found
No related tags found
No related merge requests found
package conn
import (
"bytes"
"net/rpc"
"testing"
)
type buf struct {
data chan byte
}
func newBuf() *buf {
b := new(buf)
b.data = make(chan byte, 10000)
return b
}
func (b *buf) Read(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
p[i] = <-b.data
}
return len(p), nil
}
func (b *buf) Write(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
b.data <- p[i]
}
return len(p), nil
}
func (b *buf) Close() error {
close(b.data)
return nil
}
func TestWriteAndParseHeader(t *testing.T) {
b := newBuf()
data := []byte("oh hey")
if err := writeHeader(b, 11, "testing.T", data); err != nil {
t.Error(err)
t.Fail()
}
var seq uint64
var method string
var plen int32
if err := parseHeader(b, &seq, &method, &plen); err != nil {
t.Error(err)
t.Fail()
}
if seq != 11 {
t.Error("Sequence number. Expected 11. Got: %v", seq)
}
if method != "testing.T" {
t.Error("Method name. Expected: testing.T. Got: %v", method)
}
if plen != int32(len(data)) {
t.Error("Payload length. Expected: %v. Got: %v", len(data), plen)
}
}
func TestClientToServer(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Request{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
query := new(Query)
query.Data = []byte("iamaquery")
if err := cc.WriteRequest(r, query); err != nil {
t.Error(err)
}
sr := new(rpc.Request)
if err := sc.ReadRequestHeader(sr); err != nil {
t.Error(err)
}
if sr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq)
}
if sr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, sr.ServiceMethod)
}
squery := new(Query)
if err := sc.ReadRequestBody(squery); err != nil {
t.Error(err)
}
if !bytes.Equal(squery.Data, query.Data) {
t.Error("Queries don't match. Expected: %v Got: %v",
string(query.Data), string(squery.Data))
}
}
func TestServerToClient(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Response{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
reply := new(Reply)
reply.Data = []byte("iamareply")
if err := sc.WriteResponse(r, reply); err != nil {
t.Error(err)
}
cr := new(rpc.Response)
if err := cc.ReadResponseHeader(cr); err != nil {
t.Error(err)
}
if cr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq)
}
if cr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, cr.ServiceMethod)
}
creply := new(Reply)
if err := cc.ReadResponseBody(creply); err != nil {
t.Error(err)
}
if !bytes.Equal(creply.Data, reply.Data) {
t.Error("Replies don't match. Expected: %v Got: %v",
string(reply.Data), string(creply.Data))
}
}
...@@ -38,16 +38,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { ...@@ -38,16 +38,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) {
return rpc.NewClientWithCodec(cc), nil return rpc.NewClientWithCodec(cc), nil
} }
func (p *Pool) Get() (*rpc.Client, error) { func (p *Pool) Call(serviceMethod string, args interface{},
select { reply interface{}) error {
case client := <-p.clients:
return client, nil client, err := p.get()
default: if err != nil {
return p.dialNew() return err
}
if err = client.Call(serviceMethod, args, reply); err != nil {
return err
} }
}
func (p *Pool) Put(client *rpc.Client) error {
select { select {
case p.clients <- client: case p.clients <- client:
return nil return nil
...@@ -56,9 +57,18 @@ func (p *Pool) Put(client *rpc.Client) error { ...@@ -56,9 +57,18 @@ func (p *Pool) Put(client *rpc.Client) error {
} }
} }
func (p *Pool) get() (*rpc.Client, error) {
select {
case client := <-p.clients:
return client, nil
default:
return p.dialNew()
}
}
func (p *Pool) Close() error { func (p *Pool) Close() error {
// We're not doing a clean exit here. A clean exit here would require // We're not doing a clean exit here. A clean exit here would require
// mutex locks around conns; which seems unnecessary just to shut down // synchronization, which seems unnecessary for now. But, we should
// the server. // add one if required later.
return nil return nil
} }
...@@ -44,14 +44,10 @@ func Connect() { ...@@ -44,14 +44,10 @@ func Connect() {
continue continue
} }
pool := conn.NewPool(addr, 5) pool := conn.NewPool(addr, 5)
client, err := pool.Get()
if err != nil {
glog.Fatal(err)
}
query := new(conn.Query) query := new(conn.Query)
query.Data = []byte("hello") query.Data = []byte("hello")
reply := new(conn.Reply) reply := new(conn.Reply)
if err = client.Call("Worker.Hello", query, reply); err != nil { if err := pool.Call("Worker.Hello", query, reply); err != nil {
glog.WithField("call", "Worker.Hello").Fatal(err) glog.WithField("call", "Worker.Hello").Fatal(err)
} }
glog.WithField("reply", string(reply.Data)).WithField("addr", addr). glog.WithField("reply", string(reply.Data)).WithField("addr", addr).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment