diff --git a/worker/assign.go b/worker/assign.go new file mode 100644 index 0000000000000000000000000000000000000000..0b019dc4723e464461abf8ba674653a6d027ef99 --- /dev/null +++ b/worker/assign.go @@ -0,0 +1,71 @@ +package worker + +import ( + "sync" + + "github.com/dgraph-io/dgraph/task" + "github.com/dgraph-io/dgraph/uid" + "github.com/google/flatbuffers/go" +) + +func createXidListBuffer(xids map[string]bool) []byte { + b := flatbuffers.NewBuilder(0) + var offsets []flatbuffers.UOffsetT + for xid := range xids { + uo := b.CreateString(xid) + offsets = append(offsets, uo) + } + + task.XidListStartXidsVector(b, len(offsets)) + for _, uo := range offsets { + b.PrependUOffsetT(uo) + } + ve := b.EndVector(len(offsets)) + + task.XidListStart(b) + task.XidListAddXids(b, ve) + lo := task.XidListEnd(b) + b.Finish(lo) + return b.Bytes[b.Head():] +} + +func getOrAssignUids( + xidList *task.XidList) (uidList []byte, rerr error) { + + wg := new(sync.WaitGroup) + uids := make([]uint64, xidList.XidsLength()) + che := make(chan error, xidList.XidsLength()) + for i := 0; i < xidList.XidsLength(); i++ { + wg.Add(1) + xid := string(xidList.Xids(i)) + + go func() { + defer wg.Done() + u, err := uid.GetOrAssign(xid, 0, 1) + if err != nil { + che <- err + return + } + uids[i] = u + }() + } + wg.Wait() + close(che) + for err := range che { + glog.WithError(err).Error("Encountered errors while getOrAssignUids") + return uidList, err + } + + b := flatbuffers.NewBuilder(0) + task.UidListStartUidsVector(b, xidList.XidsLength()) + for i := len(uids) - 1; i >= 0; i-- { + b.PrependUint64(uids[i]) + } + ve := b.EndVector(xidList.XidsLength()) + + task.UidListStart(b) + task.UidListAddUids(b, ve) + uend := task.UidListEnd(b) + b.Finish(uend) + return b.Bytes[b.Head():], nil +} diff --git a/worker/assign_test.go b/worker/assign_test.go new file mode 100644 index 0000000000000000000000000000000000000000..952e0017eacde724511bcba418e8b7f3a4ccb7de --- /dev/null +++ b/worker/assign_test.go @@ -0,0 +1,36 @@ +package worker + +import ( + "testing" + + "github.com/dgraph-io/dgraph/task" + "github.com/google/flatbuffers/go" +) + +func TestXidListBuffer(t *testing.T) { + xids := map[string]bool{ + "b.0453": true, + "d.z1sz": true, + "e.abcd": true, + } + + buf := createXidListBuffer(xids) + + uo := flatbuffers.GetUOffsetT(buf) + xl := new(task.XidList) + xl.Init(buf, uo) + + if xl.XidsLength() != len(xids) { + t.Errorf("Expected: %v. Got: %v", len(xids), xl.XidsLength()) + } + for i := 0; i < xl.XidsLength(); i++ { + xid := string(xl.Xids(i)) + t.Logf("Found: %v", xid) + xids[xid] = false + } + for xid, untouched := range xids { + if untouched { + t.Errorf("Expected xid: %v to be part of the buffer.", xid) + } + } +} diff --git a/worker/worker.go b/worker/worker.go index c09afe8abc263ca7a9dadd48570cb63dffffeb55..3fd932e8b4847664456b14a843f9f4376fe97c0b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -91,6 +91,22 @@ func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { return nil } +func (w *Worker) GetOrAssign(query *conn.Query, + reply *conn.Reply) (rerr error) { + + uo := flatbuffers.GetUOffsetT(query.Data) + xids := new(task.XidList) + xids.Init(query.Data, uo) + + if instanceIdx != 0 { + glog.WithField("instanceIdx", instanceIdx). + WithField("GetOrAssign", true). + Fatal("We shouldn't be receiving this request.") + } + reply.Data, rerr = getOrAssignUids(xids) + return +} + func (w *Worker) Mutate(query *conn.Query, reply *conn.Reply) (rerr error) { m := new(Mutations) if err := m.Decode(query.Data); err != nil { diff --git a/worker/worker_test.go b/worker/worker_test.go index 809eb05c876431c5533e50a30f0c200d45cbadfb..e16a6f81ea57b27a654cb30e71d830fd38978ccf 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -83,7 +83,7 @@ func TestProcessTask(t *testing.T) { addEdge(t, edge, posting.GetOrCreate(posting.Key(12, "friend"), ps)) query := NewQuery("friend", []uint64{10, 11, 12}) - result, err := ProcessTask(query) + result, err := processTask(query) if err != nil { t.Error(err) }