Skip to content
Snippets Groups Projects
Commit a477eed5 authored by Manish R Jain's avatar Manish R Jain
Browse files

Merge pull request #39 from dgraph-io/dmuts

Encapsulate Call within the client pool library, so we don't have to …
parents 617e1321 9c8b2205
No related branches found
No related tags found
No related merge requests found
package conn
import (
"bytes"
"net/rpc"
"testing"
)
type buf struct {
data chan byte
}
func newBuf() *buf {
b := new(buf)
b.data = make(chan byte, 10000)
return b
}
func (b *buf) Read(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
p[i] = <-b.data
}
return len(p), nil
}
func (b *buf) Write(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
b.data <- p[i]
}
return len(p), nil
}
func (b *buf) Close() error {
close(b.data)
return nil
}
func TestWriteAndParseHeader(t *testing.T) {
b := newBuf()
data := []byte("oh hey")
if err := writeHeader(b, 11, "testing.T", data); err != nil {
t.Error(err)
t.Fail()
}
var seq uint64
var method string
var plen int32
if err := parseHeader(b, &seq, &method, &plen); err != nil {
t.Error(err)
t.Fail()
}
if seq != 11 {
t.Error("Sequence number. Expected 11. Got: %v", seq)
}
if method != "testing.T" {
t.Error("Method name. Expected: testing.T. Got: %v", method)
}
if plen != int32(len(data)) {
t.Error("Payload length. Expected: %v. Got: %v", len(data), plen)
}
}
func TestClientToServer(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Request{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
query := new(Query)
query.Data = []byte("iamaquery")
if err := cc.WriteRequest(r, query); err != nil {
t.Error(err)
}
sr := new(rpc.Request)
if err := sc.ReadRequestHeader(sr); err != nil {
t.Error(err)
}
if sr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq)
}
if sr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, sr.ServiceMethod)
}
squery := new(Query)
if err := sc.ReadRequestBody(squery); err != nil {
t.Error(err)
}
if !bytes.Equal(squery.Data, query.Data) {
t.Error("Queries don't match. Expected: %v Got: %v",
string(query.Data), string(squery.Data))
}
}
func TestServerToClient(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Response{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
reply := new(Reply)
reply.Data = []byte("iamareply")
if err := sc.WriteResponse(r, reply); err != nil {
t.Error(err)
}
cr := new(rpc.Response)
if err := cc.ReadResponseHeader(cr); err != nil {
t.Error(err)
}
if cr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq)
}
if cr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, cr.ServiceMethod)
}
creply := new(Reply)
if err := cc.ReadResponseBody(creply); err != nil {
t.Error(err)
}
if !bytes.Equal(creply.Data, reply.Data) {
t.Error("Replies don't match. Expected: %v Got: %v",
string(reply.Data), string(creply.Data))
}
}
......@@ -3,6 +3,8 @@ package conn
import (
"net"
"net/rpc"
"strings"
"time"
"github.com/dgraph-io/dgraph/x"
)
......@@ -28,7 +30,24 @@ func NewPool(addr string, maxCap int) *Pool {
}
func (p *Pool) dialNew() (*rpc.Client, error) {
nconn, err := net.Dial("tcp", p.addr)
d := &net.Dialer{
Timeout: 3 * time.Minute,
}
var nconn net.Conn
var err error
for i := 0; i < 10; i++ {
nconn, err = d.Dial("tcp", p.addr)
if err == nil {
break
}
if !strings.Contains(err.Error(), "refused") {
break
}
glog.WithField("error", err).WithField("addr", p.addr).
Info("Retrying connection...")
time.Sleep(10 * time.Second)
}
if err != nil {
return nil, err
}
......@@ -38,16 +57,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) {
return rpc.NewClientWithCodec(cc), nil
}
func (p *Pool) Get() (*rpc.Client, error) {
select {
case client := <-p.clients:
return client, nil
default:
return p.dialNew()
func (p *Pool) Call(serviceMethod string, args interface{},
reply interface{}) error {
client, err := p.get()
if err != nil {
return err
}
if err = client.Call(serviceMethod, args, reply); err != nil {
return err
}
}
func (p *Pool) Put(client *rpc.Client) error {
select {
case p.clients <- client:
return nil
......@@ -56,9 +76,18 @@ func (p *Pool) Put(client *rpc.Client) error {
}
}
func (p *Pool) get() (*rpc.Client, error) {
select {
case client := <-p.clients:
return client, nil
default:
return p.dialNew()
}
}
func (p *Pool) Close() error {
// We're not doing a clean exit here. A clean exit here would require
// mutex locks around conns; which seems unnecessary just to shut down
// the server.
// synchronization, which seems unnecessary for now. But, we should
// add one if required later.
return nil
}
......@@ -38,20 +38,15 @@ func Connect() {
}
addrs := strings.Split(*workers, ",")
var pools []*conn.Pool
for _, addr := range addrs {
if len(addr) == 0 {
continue
}
pool := conn.NewPool(addr, 5)
client, err := pool.Get()
if err != nil {
glog.Fatal(err)
}
query := new(conn.Query)
query.Data = []byte("hello")
reply := new(conn.Reply)
if err = client.Call("Worker.Hello", query, reply); err != nil {
if err := pool.Call("Worker.Hello", query, reply); err != nil {
glog.WithField("call", "Worker.Hello").Fatal(err)
}
glog.WithField("reply", string(reply.Data)).WithField("addr", addr).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment