diff --git a/conn/pool.go b/conn/pool.go index b8aa13057840b02c2db6687e4a04be52194c1d7e..40e00e8bb458dd6a61068263a2cfcfabcfd7a93f 100644 --- a/conn/pool.go +++ b/conn/pool.go @@ -13,12 +13,12 @@ var glog = x.Log("conn") type Pool struct { clients chan *rpc.Client - addr string + Addr string } func NewPool(addr string, maxCap int) *Pool { p := new(Pool) - p.addr = addr + p.Addr = addr p.clients = make(chan *rpc.Client, maxCap) client, err := p.dialNew() if err != nil { @@ -36,7 +36,7 @@ func (p *Pool) dialNew() (*rpc.Client, error) { var nconn net.Conn var err error for i := 0; i < 10; i++ { - nconn, err = d.Dial("tcp", p.addr) + nconn, err = d.Dial("tcp", p.Addr) if err == nil { break } @@ -44,7 +44,7 @@ func (p *Pool) dialNew() (*rpc.Client, error) { break } - glog.WithField("error", err).WithField("addr", p.addr). + glog.WithField("error", err).WithField("addr", p.Addr). Info("Retrying connection...") time.Sleep(10 * time.Second) } diff --git a/query/query.go b/query/query.go index c5ea0bd21dc9e7757b5023425fd9d4332d8fd0f2..e564c87355b2423fe08acc5db0a26d1d4b5ac2bb 100644 --- a/query/query.go +++ b/query/query.go @@ -406,7 +406,7 @@ func ProcessGraph(sg *SubGraph, rch chan error) { var err error if len(sg.query) > 0 && sg.Attr != "_root_" { // This task execution would go over the wire in later versions. - sg.result, err = worker.ProcessTask(sg.query) + sg.result, err = worker.ProcessTaskOverNetwork(sg.query) if err != nil { x.Err(glog, err).Error("While processing task.") rch <- err diff --git a/query/query_test.go b/query/query_test.go index 397bcaa04f13ec318ff8c2a99c9bbbd6979ace71..7589cdf349da3b416a1be080d2f403fba689142c 100644 --- a/query/query_test.go +++ b/query/query_test.go @@ -103,7 +103,7 @@ func TestNewGraph(t *testing.T) { t.Error(err) } - worker.Init(ps, nil, nil) + worker.Init(ps, nil, 0, 1) uo := flatbuffers.GetUOffsetT(sg.result) r := new(task.Result) @@ -134,7 +134,7 @@ func populateGraph(t *testing.T) (string, *store.Store) { ps := new(store.Store) ps.Init(dir) - worker.Init(ps, nil, nil) + worker.Init(ps, nil, 0, 1) clog := commit.NewLogger(dir, "mutations", 50<<20) clog.Init() diff --git a/server/main.go b/server/main.go index 1200e7669f0dda124f1531dd3c3afe018ed7a413..5882f182aad1f81e559e9131d7d61dff2d25b6eb 100644 --- a/server/main.go +++ b/server/main.go @@ -144,21 +144,23 @@ func main() { defer clog.Close() addrs := strings.Split(*workers, ",") + lenAddr := uint64(len(addrs)) posting.Init(clog) + if *instanceIdx != 0 { - worker.Init(ps, nil, addrs) + worker.Init(ps, nil, *instanceIdx, lenAddr) uid.Init(nil) } else { uidStore := new(store.Store) uidStore.Init(*uidDir) defer uidStore.Close() // Only server instance 0 will have uidStore - worker.Init(ps, uidStore, addrs) + worker.Init(ps, uidStore, *instanceIdx, lenAddr) uid.Init(uidStore) } - worker.Connect() + worker.Connect(addrs) http.HandleFunc("/query", queryHandler) glog.WithField("port", *port).Info("Listening for requests...") diff --git a/server/main_test.go b/server/main_test.go index 78c20da95f409349df5aed4163286ba06bb8b159..962256cdcbc3c1e314d34d0fd9f7272894b98f64 100644 --- a/server/main_test.go +++ b/server/main_test.go @@ -62,7 +62,7 @@ func prepare() (dir1, dir2 string, ps *store.Store, clog *commit.Logger, rerr er clog.Init() posting.Init(clog) - worker.Init(ps, nil, nil) + worker.Init(ps, nil, 0, 1) uid.Init(ps) loader.Init(ps, ps) diff --git a/worker/worker.go b/worker/worker.go index 20ee65a8d3401d56a1196c43e6c2af77547ea4ab..1ed369419aa937eee63558429715b4f00c7b8593 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -11,6 +11,7 @@ import ( "github.com/dgraph-io/dgraph/store" "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/x" + "github.com/dgryski/go-farm" "github.com/google/flatbuffers/go" ) @@ -18,17 +19,18 @@ var workerPort = flag.String("workerport", ":12345", "Port used by worker for internal communication.") var glog = x.Log("worker") -var dataStore, xiduidStore *store.Store +var dataStore, uidStore *store.Store var pools []*conn.Pool -var addrs []string +var numInstances, instanceIdx uint64 -func Init(ps, xuStore *store.Store, workerList []string) { +func Init(ps, uStore *store.Store, idx, numInst uint64) { dataStore = ps - xiduidStore = xuStore - addrs = workerList + uidStore = uStore + instanceIdx = idx + numInstances = numInst } -func Connect() { +func Connect(workerList []string) { w := new(Worker) if err := rpc.Register(w); err != nil { glog.Fatal(err) @@ -36,8 +38,13 @@ func Connect() { if err := runServer(*workerPort); err != nil { glog.Fatal(err) } + if uint64(len(workerList)) != numInstances { + glog.WithField("len(list)", len(workerList)). + WithField("numInstances", numInstances). + Fatalf("Wrong number of instances in workerList") + } - for _, addr := range addrs { + for _, addr := range workerList { if len(addr) == 0 { continue } @@ -56,16 +63,49 @@ func Connect() { glog.Info("Server started. Clients connected.") } +func ProcessTaskOverNetwork(qu []byte) (result []byte, rerr error) { + uo := flatbuffers.GetUOffsetT(qu) + q := new(task.Query) + q.Init(qu, uo) + + attr := string(q.Attr()) + idx := farm.Fingerprint64([]byte(attr)) % numInstances + + var runHere bool + if attr == "_xid_" || attr == "_uid_" { + idx = 0 + runHere = (instanceIdx == 0) + } else { + runHere = (instanceIdx == idx) + } + + if runHere { + return ProcessTask(qu) + } + + pool := pools[idx] + addr := pool.Addr + query := new(conn.Query) + query.Data = qu + reply := new(conn.Reply) + if err := pool.Call("Worker.ServeTask", query, reply); err != nil { + glog.WithField("call", "Worker.ServeTask").Fatal(err) + } + glog.WithField("reply", string(reply.Data)).WithField("addr", addr). + Info("Got reply from server") + return reply.Data, nil +} + func ProcessTask(query []byte) (result []byte, rerr error) { uo := flatbuffers.GetUOffsetT(query) q := new(task.Query) q.Init(query, uo) + attr := string(q.Attr()) b := flatbuffers.NewBuilder(0) voffsets := make([]flatbuffers.UOffsetT, q.UidsLength()) uoffsets := make([]flatbuffers.UOffsetT, q.UidsLength()) - attr := string(q.Attr()) for i := 0; i < q.UidsLength(); i++ { uid := q.Uids(i) key := posting.Key(uid, attr) @@ -133,6 +173,22 @@ func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { return nil } +func (w *Worker) ServeTask(query *conn.Query, reply *conn.Reply) (rerr error) { + uo := flatbuffers.GetUOffsetT(query.Data) + q := new(task.Query) + q.Init(query.Data, uo) + attr := string(q.Attr()) + + if farm.Fingerprint64([]byte(attr))%numInstances == instanceIdx { + reply.Data, rerr = ProcessTask(query.Data) + } else { + glog.WithField("attribute", attr). + WithField("instanceIdx", instanceIdx). + Fatalf("Request sent to wrong server") + } + return rerr +} + func serveRequests(irwc io.ReadWriteCloser) { for { sc := &conn.ServerCodec{ diff --git a/worker/worker_test.go b/worker/worker_test.go index 35ed56842824af45d10239d741e4f8e9335e495d..809eb05c876431c5533e50a30f0c200d45cbadfb 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -58,7 +58,7 @@ func TestProcessTask(t *testing.T) { defer clog.Close() posting.Init(clog) - Init(ps, nil, nil) + Init(ps, nil, 0, 1) edge := x.DirectedEdge{ ValueId: 23,