diff --git a/:q b/:q deleted file mode 100644 index ba6f93ab838b6b5168f8507a27bba976085c0a3e..0000000000000000000000000000000000000000 --- a/:q +++ /dev/null @@ -1,5 +0,0 @@ -worker/worker.go|30 col 62| : expected type, found ')' -worker/worker.go|31 col 12| : expected ';', found '=' -worker/worker.go|39 col 2| : expected declaration, found 'if' -worker/worker.go|48 col 2| : expected declaration, found 'for' -worker/worker.go|107 col 4| : expected declaration, found 'if' diff --git a/conn/codec_test.go b/conn/codec_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ce3fe0e7b2dae2799f004bcb06338d93a951ae4b --- /dev/null +++ b/conn/codec_test.go @@ -0,0 +1,145 @@ +package conn + +import ( + "bytes" + "net/rpc" + "testing" +) + +type buf struct { + data chan byte +} + +func newBuf() *buf { + b := new(buf) + b.data = make(chan byte, 10000) + return b +} + +func (b *buf) Read(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + p[i] = <-b.data + } + return len(p), nil +} + +func (b *buf) Write(p []byte) (n int, err error) { + for i := 0; i < len(p); i++ { + b.data <- p[i] + } + return len(p), nil +} + +func (b *buf) Close() error { + close(b.data) + return nil +} + +func TestWriteAndParseHeader(t *testing.T) { + b := newBuf() + data := []byte("oh hey") + if err := writeHeader(b, 11, "testing.T", data); err != nil { + t.Error(err) + t.Fail() + } + var seq uint64 + var method string + var plen int32 + if err := parseHeader(b, &seq, &method, &plen); err != nil { + t.Error(err) + t.Fail() + } + if seq != 11 { + t.Error("Sequence number. Expected 11. Got: %v", seq) + } + if method != "testing.T" { + t.Error("Method name. Expected: testing.T. Got: %v", method) + } + if plen != int32(len(data)) { + t.Error("Payload length. Expected: %v. Got: %v", len(data), plen) + } +} + +func TestClientToServer(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Request{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + query := new(Query) + query.Data = []byte("iamaquery") + if err := cc.WriteRequest(r, query); err != nil { + t.Error(err) + } + + sr := new(rpc.Request) + if err := sc.ReadRequestHeader(sr); err != nil { + t.Error(err) + } + if sr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) + } + if sr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, sr.ServiceMethod) + } + + squery := new(Query) + if err := sc.ReadRequestBody(squery); err != nil { + t.Error(err) + } + if !bytes.Equal(squery.Data, query.Data) { + t.Error("Queries don't match. Expected: %v Got: %v", + string(query.Data), string(squery.Data)) + } +} + +func TestServerToClient(t *testing.T) { + b := newBuf() + cc := &ClientCodec{ + Rwc: b, + } + sc := &ServerCodec{ + Rwc: b, + } + + r := &rpc.Response{ + ServiceMethod: "Test.ClientServer", + Seq: 11, + } + + reply := new(Reply) + reply.Data = []byte("iamareply") + if err := sc.WriteResponse(r, reply); err != nil { + t.Error(err) + } + + cr := new(rpc.Response) + if err := cc.ReadResponseHeader(cr); err != nil { + t.Error(err) + } + if cr.Seq != r.Seq { + t.Error("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) + } + if cr.ServiceMethod != r.ServiceMethod { + t.Error("ServiceMethod. Expected: %v. Got: %v", + r.ServiceMethod, cr.ServiceMethod) + } + + creply := new(Reply) + if err := cc.ReadResponseBody(creply); err != nil { + t.Error(err) + } + if !bytes.Equal(creply.Data, reply.Data) { + t.Error("Replies don't match. Expected: %v Got: %v", + string(reply.Data), string(creply.Data)) + } +} diff --git a/conn/pool.go b/conn/pool.go index 7168c9cf3eb1f96dc89feea11a7c1c96d3cdbc3b..b8aa13057840b02c2db6687e4a04be52194c1d7e 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -3,6 +3,8 @@ package conn import ( "net" "net/rpc" + "strings" + "time" "github.com/dgraph-io/dgraph/x" ) @@ -28,7 +30,24 @@ func NewPool(addr string, maxCap int) *Pool { } func (p *Pool) dialNew() (*rpc.Client, error) { - nconn, err := net.Dial("tcp", p.addr) + d := &net.Dialer{ + Timeout: 3 * time.Minute, + } + var nconn net.Conn + var err error + for i := 0; i < 10; i++ { + nconn, err = d.Dial("tcp", p.addr) + if err == nil { + break + } + if !strings.Contains(err.Error(), "refused") { + break + } + + glog.WithField("error", err).WithField("addr", p.addr). + Info("Retrying connection...") + time.Sleep(10 * time.Second) + } if err != nil { return nil, err } @@ -38,16 +57,17 @@ func (p *Pool) dialNew() (*rpc.Client, error) { return rpc.NewClientWithCodec(cc), nil } -func (p *Pool) Get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() +func (p *Pool) Call(serviceMethod string, args interface{}, + reply interface{}) error { + + client, err := p.get() + if err != nil { + return err + } + if err = client.Call(serviceMethod, args, reply); err != nil { + return err } -} -func (p *Pool) Put(client *rpc.Client) error { select { case p.clients <- client: return nil @@ -56,9 +76,18 @@ func (p *Pool) Put(client *rpc.Client) error { } } +func (p *Pool) get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + func (p *Pool) Close() error { // We're not doing a clean exit here. A clean exit here would require - // mutex locks around conns; which seems unnecessary just to shut down - // the server. + // synchronization, which seems unnecessary for now. But, we should + // add one if required later. return nil } diff --git a/query/query.go b/query/query.go index b063d850c1df14a8a0845c59f039682562760c57..c5ea0bd21dc9e7757b5023425fd9d4332d8fd0f2 100644 --- a/query/query.go +++ b/query/query.go @@ -170,7 +170,7 @@ func postTraverse(g *SubGraph) (result map[uint64]interface{}, rerr error) { r.Init(g.result, ro) if q.UidsLength() != r.UidmatrixLength() { - glog.Fatal("Result uidmatrixlength: %v. Query uidslength: %v", + glog.Fatalf("Result uidmatrixlength: %v. Query uidslength: %v", r.UidmatrixLength(), q.UidsLength()) } if q.UidsLength() != r.ValuesLength() { diff --git a/query/query_test.go b/query/query_test.go index fde7afd4af9624ac37a0f5d9ca4e90c29e355a70..95b124ce242869b05bd62e4434cb530121f8f922 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -81,7 +81,7 @@ func checkSingleValue(t *testing.T, child *SubGraph, } if ul.UidsLength() != 0 { - t.Error("Expected uids length 0. Got: %v", ul.UidsLength()) + t.Errorf("Expected uids length 0. Got: %v", ul.UidsLength()) } checkName(t, r, 0, value) } @@ -103,7 +103,7 @@ func TestNewGraph(t *testing.T) { t.Error(err) } - worker.Init(ps, nil, 0, 1) + worker.Init(ps, nil, nil, 0, 1) uo := flatbuffers.GetUOffsetT(sg.result) r := new(task.Result) @@ -134,7 +134,7 @@ func populateGraph(t *testing.T) (string, *store.Store) { ps := new(store.Store) ps.Init(dir) - worker.Init(ps, nil, 0, 1) + worker.Init(ps, nil, nil, 0, 1) clog := commit.NewLogger(dir, "mutations", 50<<20) clog.Init() diff --git a/server/main.go b/server/main.go index 499f689cf6010b987504f92ffee6d1112230253b..c89df478b3c33046fea7c3aea68bd7dce85becd3 100644 --- a/server/main.go +++ b/server/main.go @@ -22,6 +22,7 @@ import ( "io/ioutil" "net/http" "runtime" + "strings" "time" "github.com/Sirupsen/logrus" @@ -38,15 +39,15 @@ import ( var glog = x.Log("server") var postingDir = flag.String("postings", "", "Directory to store posting lists") -var xiduidDir = flag.String("xiduid", "", "XID UID posting lists directory") +var uidDir = flag.String("uids", "", "XID UID posting lists directory") var mutationDir = flag.String("mutations", "", "Directory to store mutations") var port = flag.String("port", "8080", "Port to run server on.") var numcpu = flag.Int("numCpu", runtime.NumCPU(), "Number of cores to be used by the process") var instanceIdx = flag.Uint64("instanceIdx", 0, "serves only entities whose Fingerprint % numInstance == instanceIdx.") -var numInstances = flag.Uint64("numInstances", 1, - "Total number of server instances") +var workers = flag.String("workers", "", + "Comma separated list of IP addresses of workers") func addCorsHeaders(w http.ResponseWriter) { w.Header().Set("Access-Control-Allow-Origin", "*") @@ -142,16 +143,22 @@ func main() { clog.Init() defer clog.Close() + addrs := strings.Split(*workers, ",") + posting.Init(clog) - if *instanceIdx == 0 { - xiduidStore := new(store.Store) - xiduidStore.Init(*xiduidDir) - defer xiduidStore.Close() - worker.Init(ps, xiduidStore, *instanceIdx, *numInstances) //Only server instance 0 will have xiduidStore - uid.Init(xiduidStore) + + if *instanceIdx != 0 { + worker.Init(ps, nil, addrs, *instanceIdx, len(addrs)) + uid.Init(nil) } else { - worker.Init(ps, nil, *instanceIdx, *numInstances) + uidStore := new(store.Store) + uidStore.Init(*uidDir) + defer uidStore.Close() + // Only server instance 0 will have uidStore + worker.Init(ps, uidStore, addrs, *instanceIdx, len(addrs)) + uid.Init(uidStore) } + worker.Connect() http.HandleFunc("/query", queryHandler) diff --git a/server/main_test.go b/server/main_test.go index 962256cdcbc3c1e314d34d0fd9f7272894b98f64..07e1642fd5593f4d236833f68f17a097ef2c1829 100644 --- a/server/main_test.go +++ b/server/main_test.go @@ -62,7 +62,7 @@ func prepare() (dir1, dir2 string, ps *store.Store, clog *commit.Logger, rerr er clog.Init() posting.Init(clog) - worker.Init(ps, nil, 0, 1) + worker.Init(ps, nil, nil, 0, 1) uid.Init(ps) loader.Init(ps, ps) diff --git a/worker/worker.go b/worker/worker.go index a6fd374aba0bb657cd922cede997b1bd65628fea..200fda61a6a3a945fd518aa7d5265fac12a1a6ed 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -5,7 +5,6 @@ import ( "io" "net" "net/rpc" - "strings" "github.com/dgraph-io/dgraph/conn" "github.com/dgraph-io/dgraph/posting" @@ -16,22 +15,22 @@ import ( "github.com/google/flatbuffers/go" ) -var workers = flag.String("workers", "", - "Comma separated list of IP addresses of workers") var workerPort = flag.String("workerport", ":12345", "Port used by worker for internal communication.") var glog = x.Log("worker") -var dataStore, xiduidStore *store.Store +var dataStore, uidStore *store.Store var pools []*conn.Pool -var addrs = strings.Split(*workers, ",") var numInstances, instanceIdx uint64 -func Init(ps, xuStore *store.Store, idx, numInst uint64) { +var addrs []string + +func Init(ps, uStore *store.Store, workerList []string, idx, numInst uint64) { dataStore = ps - xiduidStore = xuStore - numInstances = numInst + uidStore = xuStore + addrs = workerList instanceIdx = idx + numInstances = numInst } func Connect() { @@ -43,21 +42,15 @@ func Connect() { glog.Fatal(err) } - addrs := strings.Split(*workers, ",") - var pools []*conn.Pool for _, addr := range addrs { if len(addr) == 0 { continue } pool := conn.NewPool(addr, 5) - client, err := pool.Get() - if err != nil { - glog.Fatal(err) - } query := new(conn.Query) query.Data = []byte("hello") reply := new(conn.Reply) - if err = client.Call("Worker.Hello", query, reply); err != nil { + if err := pool.Call("Worker.Hello", query, reply); err != nil { glog.WithField("call", "Worker.Hello").Fatal(err) } glog.WithField("reply", string(reply.Data)).WithField("addr", addr). diff --git a/worker/worker_test.go b/worker/worker_test.go index 809eb05c876431c5533e50a30f0c200d45cbadfb..8730617e02a7b571142215b255697d2d2b1eb773 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -58,7 +58,7 @@ func TestProcessTask(t *testing.T) { defer clog.Close() posting.Init(clog) - Init(ps, nil, 0, 1) + Init(ps, nil, nil, 0, 1) edge := x.DirectedEdge{ ValueId: 23,