diff --git a/conn/codec_test.go b/conn/codec_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ce3fe0e7b2dae2799f004bcb06338d93a951ae4b --- /dev/null +++ b/conn/codec_test.go @@ -0,0 +1,145 @@ +package conn + +import ( + "bytes" + "net/rpc" + "testing" +) + +type buf struct { + data chan byte +} + +func newBuf() *buf { + b := new(buf) + b.data = make(chan byte, 10000) + return b +} + +func (b *buf) Read(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + p[i] = <-b.data + } + return len(p), nil +} + +func (b *buf) Write(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + b.data <- p[i] + } + return len(p), nil +} + +func (b *buf) Close() error { + close(b.data) + return nil +} + +func TestWriteAndParseHeader(t *testing.T) { + b := newBuf() + data := []byte("oh hey") + if err := writeHeader(b, 11, "testing.T", data); err != nil { + t.Error(err) + t.Fail() + } + var seq uint64 + var method string + var plen int32 + if err := parseHeader(b, &seq, &method, &plen); err != nil { + t.Error(err) + t.Fail() + } + if seq != 11 { + t.Error("Sequence number. Expected 11. Got: %v", seq) + } + if method != "testing.T" { + t.Error("Method name. Expected: testing.T. Got: %v", method) + } + if plen != int32(len(data)) { + t.Error("Payload length. Expected: %v. Got: %v", len(data), plen) + } +} + +func TestClientToServer(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Request{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + query := new(Query) + query.Data = []byte("iamaquery") + if err := cc.WriteRequest(r, query); err != nil { + t.Error(err) + } + + sr := new(rpc.Request) + if err := sc.ReadRequestHeader(sr); err != nil { + t.Error(err) + } + if sr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) + } + if sr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, sr.ServiceMethod) + } + + squery := new(Query) + if err := sc.ReadRequestBody(squery); err != nil { + t.Error(err) + } + if !bytes.Equal(squery.Data, query.Data) { + t.Error("Queries don't match. Expected: %v Got: %v", + string(query.Data), string(squery.Data)) + } +} + +func TestServerToClient(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Response{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + reply := new(Reply) + reply.Data = []byte("iamareply") + if err := sc.WriteResponse(r, reply); err != nil { + t.Error(err) + } + + cr := new(rpc.Response) + if err := cc.ReadResponseHeader(cr); err != nil { + t.Error(err) + } + if cr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) + } + if cr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, cr.ServiceMethod) + } + + creply := new(Reply) + if err := cc.ReadResponseBody(creply); err != nil { + t.Error(err) + } + if !bytes.Equal(creply.Data, reply.Data) { + t.Error("Replies don't match. Expected: %v Got: %v", + string(reply.Data), string(creply.Data)) + } +} diff --git a/conn/pool.go b/conn/pool.go index 7168c9cf3eb1f96dc89feea11a7c1c96d3cdbc3b..b8aa13057840b02c2db6687e4a04be52194c1d7e 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -3,6 +3,8 @@ package conn import ( "net" "net/rpc" + "strings" + "time" "github.com/dgraph-io/dgraph/x" ) @@ -28,7 +30,24 @@ func NewPool(addr string, maxCap int) *Pool { } func (p *Pool) dialNew() (*rpc.Client, error) { - nconn, err := net.Dial("tcp", p.addr) + d := &net.Dialer{ + Timeout: 3 * time.Minute, + } + var nconn net.Conn + var err error + for i := 0; i < 10; i++ { + nconn, err = d.Dial("tcp", p.addr) + if err == nil { + break + } + if !strings.Contains(err.Error(), "refused") { + break + } + + glog.WithField("error", err).WithField("addr", p.addr). + Info("Retrying connection...") + time.Sleep(10 * time.Second) + } if err != nil { return nil, err } @@ -38,16 +57,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { return rpc.NewClientWithCodec(cc), nil } -func (p *Pool) Get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() +func (p *Pool) Call(serviceMethod string, args interface{}, + reply interface{}) error { + + client, err := p.get() + if err != nil { + return err + } + if err = client.Call(serviceMethod, args, reply); err != nil { + return err } -} -func (p *Pool) Put(client *rpc.Client) error { select { case p.clients <- client: return nil @@ -56,9 +76,18 @@ func (p *Pool) Put(client *rpc.Client) error { } } +func (p *Pool) get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + func (p *Pool) Close() error { // We're not doing a clean exit here. A clean exit here would require - // mutex locks around conns; which seems unnecessary just to shut down - // the server. + // synchronization, which seems unnecessary for now. But, we should + // add one if required later. return nil } diff --git a/worker/worker.go b/worker/worker.go index 5966a632afe1286ea589ac28d7e7462a59525396..9b6e2b91fed605a92d587e137a94c1982cbd449f 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -38,20 +38,15 @@ func Connect() { } addrs := strings.Split(*workers, ",") - var pools []*conn.Pool for _, addr := range addrs { if len(addr) == 0 { continue } pool := conn.NewPool(addr, 5) - client, err := pool.Get() - if err != nil { - glog.Fatal(err) - } query := new(conn.Query) query.Data = []byte("hello") reply := new(conn.Reply) - if err = client.Call("Worker.Hello", query, reply); err != nil { + if err := pool.Call("Worker.Hello", query, reply); err != nil { glog.WithField("call", "Worker.Hello").Fatal(err) } glog.WithField("reply", string(reply.Data)).WithField("addr", addr).