Skip to content
Snippets Groups Projects
Commit 0a5f6679 authored by Ashwin's avatar Ashwin
Browse files

Merge branch 'master' into distributed

parents 3e917ef0 a477eed5
No related branches found
No related tags found
No related merge requests found
package conn
import (
"bytes"
"net/rpc"
"testing"
)
type buf struct {
data chan byte
}
func newBuf() *buf {
b := new(buf)
b.data = make(chan byte, 10000)
return b
}
func (b *buf) Read(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
p[i] = <-b.data
}
return len(p), nil
}
func (b *buf) Write(p []byte) (n int, err error) {
for i := 0; i < len(p); i++ {
b.data <- p[i]
}
return len(p), nil
}
func (b *buf) Close() error {
close(b.data)
return nil
}
func TestWriteAndParseHeader(t *testing.T) {
b := newBuf()
data := []byte("oh hey")
if err := writeHeader(b, 11, "testing.T", data); err != nil {
t.Error(err)
t.Fail()
}
var seq uint64
var method string
var plen int32
if err := parseHeader(b, &seq, &method, &plen); err != nil {
t.Error(err)
t.Fail()
}
if seq != 11 {
t.Error("Sequence number. Expected 11. Got: %v", seq)
}
if method != "testing.T" {
t.Error("Method name. Expected: testing.T. Got: %v", method)
}
if plen != int32(len(data)) {
t.Error("Payload length. Expected: %v. Got: %v", len(data), plen)
}
}
func TestClientToServer(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Request{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
query := new(Query)
query.Data = []byte("iamaquery")
if err := cc.WriteRequest(r, query); err != nil {
t.Error(err)
}
sr := new(rpc.Request)
if err := sc.ReadRequestHeader(sr); err != nil {
t.Error(err)
}
if sr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq)
}
if sr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, sr.ServiceMethod)
}
squery := new(Query)
if err := sc.ReadRequestBody(squery); err != nil {
t.Error(err)
}
if !bytes.Equal(squery.Data, query.Data) {
t.Error("Queries don't match. Expected: %v Got: %v",
string(query.Data), string(squery.Data))
}
}
func TestServerToClient(t *testing.T) {
b := newBuf()
cc := &ClientCodec{
Rwc: b,
}
sc := &ServerCodec{
Rwc: b,
}
r := &rpc.Response{
ServiceMethod: "Test.ClientServer",
Seq: 11,
}
reply := new(Reply)
reply.Data = []byte("iamareply")
if err := sc.WriteResponse(r, reply); err != nil {
t.Error(err)
}
cr := new(rpc.Response)
if err := cc.ReadResponseHeader(cr); err != nil {
t.Error(err)
}
if cr.Seq != r.Seq {
t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq)
}
if cr.ServiceMethod != r.ServiceMethod {
t.Error("ServiceMethod. Expected: %v. Got: %v",
r.ServiceMethod, cr.ServiceMethod)
}
creply := new(Reply)
if err := cc.ReadResponseBody(creply); err != nil {
t.Error(err)
}
if !bytes.Equal(creply.Data, reply.Data) {
t.Error("Replies don't match. Expected: %v Got: %v",
string(reply.Data), string(creply.Data))
}
}
...@@ -3,6 +3,8 @@ package conn ...@@ -3,6 +3,8 @@ package conn
import ( import (
"net" "net"
"net/rpc" "net/rpc"
"strings"
"time"
"github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/dgraph/x"
) )
...@@ -28,7 +30,24 @@ func NewPool(addr string, maxCap int) *Pool { ...@@ -28,7 +30,24 @@ func NewPool(addr string, maxCap int) *Pool {
} }
func (p *Pool) dialNew() (*rpc.Client, error) { func (p *Pool) dialNew() (*rpc.Client, error) {
nconn, err := net.Dial("tcp", p.addr) d := &net.Dialer{
Timeout: 3 * time.Minute,
}
var nconn net.Conn
var err error
for i := 0; i < 10; i++ {
nconn, err = d.Dial("tcp", p.addr)
if err == nil {
break
}
if !strings.Contains(err.Error(), "refused") {
break
}
glog.WithField("error", err).WithField("addr", p.addr).
Info("Retrying connection...")
time.Sleep(10 * time.Second)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -38,16 +57,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { ...@@ -38,16 +57,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) {
return rpc.NewClientWithCodec(cc), nil return rpc.NewClientWithCodec(cc), nil
} }
func (p *Pool) Get() (*rpc.Client, error) { func (p *Pool) Call(serviceMethod string, args interface{},
select { reply interface{}) error {
case client := <-p.clients:
return client, nil client, err := p.get()
default: if err != nil {
return p.dialNew() return err
}
if err = client.Call(serviceMethod, args, reply); err != nil {
return err
} }
}
func (p *Pool) Put(client *rpc.Client) error {
select { select {
case p.clients <- client: case p.clients <- client:
return nil return nil
...@@ -56,9 +76,18 @@ func (p *Pool) Put(client *rpc.Client) error { ...@@ -56,9 +76,18 @@ func (p *Pool) Put(client *rpc.Client) error {
} }
} }
func (p *Pool) get() (*rpc.Client, error) {
select {
case client := <-p.clients:
return client, nil
default:
return p.dialNew()
}
}
func (p *Pool) Close() error { func (p *Pool) Close() error {
// We're not doing a clean exit here. A clean exit here would require // We're not doing a clean exit here. A clean exit here would require
// mutex locks around conns; which seems unnecessary just to shut down // synchronization, which seems unnecessary for now. But, we should
// the server. // add one if required later.
return nil return nil
} }
...@@ -170,7 +170,7 @@ func postTraverse(g *SubGraph) (result map[uint64]interface{}, rerr error) { ...@@ -170,7 +170,7 @@ func postTraverse(g *SubGraph) (result map[uint64]interface{}, rerr error) {
r.Init(g.result, ro) r.Init(g.result, ro)
if q.UidsLength() != r.UidmatrixLength() { if q.UidsLength() != r.UidmatrixLength() {
glog.Fatal("Result uidmatrixlength: %v. Query uidslength: %v", glog.Fatalf("Result uidmatrixlength: %v. Query uidslength: %v",
r.UidmatrixLength(), q.UidsLength()) r.UidmatrixLength(), q.UidsLength())
} }
if q.UidsLength() != r.ValuesLength() { if q.UidsLength() != r.ValuesLength() {
......
...@@ -81,7 +81,7 @@ func checkSingleValue(t *testing.T, child *SubGraph, ...@@ -81,7 +81,7 @@ func checkSingleValue(t *testing.T, child *SubGraph,
} }
if ul.UidsLength() != 0 { if ul.UidsLength() != 0 {
t.Error("Expected uids length 0. Got: %v", ul.UidsLength()) t.Errorf("Expected uids length 0. Got: %v", ul.UidsLength())
} }
checkName(t, r, 0, value) checkName(t, r, 0, value)
} }
......
...@@ -39,20 +39,15 @@ func Connect() { ...@@ -39,20 +39,15 @@ func Connect() {
} }
addrs := strings.Split(*workers, ",") addrs := strings.Split(*workers, ",")
var pools []*conn.Pool
for _, addr := range addrs { for _, addr := range addrs {
if len(addr) == 0 { if len(addr) == 0 {
continue continue
} }
pool := conn.NewPool(addr, 5) pool := conn.NewPool(addr, 5)
client, err := pool.Get()
if err != nil {
glog.Fatal(err)
}
query := new(conn.Query) query := new(conn.Query)
query.Data = []byte("hello") query.Data = []byte("hello")
reply := new(conn.Reply) reply := new(conn.Reply)
if err = client.Call("Worker.Hello", query, reply); err != nil { if err := pool.Call("Worker.Hello", query, reply); err != nil {
glog.WithField("call", "Worker.Hello").Fatal(err) glog.WithField("call", "Worker.Hello").Fatal(err)
} }
glog.WithField("reply", string(reply.Data)).WithField("addr", addr). glog.WithField("reply", string(reply.Data)).WithField("addr", addr).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment