diff --git a/posting/worker.go b/posting/worker.go index f7c1957efd7d812788d78be2fccb0b1a1819abe1..04392d9b1a09f4124975139cb96deae4372ade5c 100644 --- a/posting/worker.go +++ b/posting/worker.go @@ -29,6 +29,7 @@ func (h *elemHeap) Pop() interface{} { } func addUids(b *flatbuffers.Builder, sorted []uint64) flatbuffers.UOffsetT { + // Invert the sorted uids to maintain same order in flatbuffers. task.ResultStartUidsVector(b, len(sorted)) for i := len(sorted) - 1; i >= 0; i-- { b.PrependUint64(sorted[i]) @@ -114,6 +115,23 @@ func ProcessQuery(query []byte) (result []byte, rerr error) { return b.Bytes[b.Head():], nil } +func NewQuery(attr string, uids []uint64) []byte { + b := flatbuffers.NewBuilder(0) + task.QueryStartUidsVector(b, len(uids)) + for i := len(uids) - 1; i >= 0; i-- { + b.PrependUint64(uids[i]) + } + vend := b.EndVector(len(uids)) + + ao := b.CreateString(attr) + task.QueryStart(b) + task.QueryAddAttr(b, ao) + task.QueryAddUids(b, vend) + qend := task.QueryEnd(b) + b.Finish(qend) + return b.Bytes[b.Head():] +} + var nilbyte []byte func init() { diff --git a/posting/worker_test.go b/posting/worker_test.go index 5bfc49e6568b2cb7141fb991a7067dce615697c1..99fb2391bccc168b1d39578e53aba3d72db89bc2 100644 --- a/posting/worker_test.go +++ b/posting/worker_test.go @@ -2,7 +2,15 @@ package posting import ( "container/heap" + "os" "testing" + "time" + + "github.com/Sirupsen/logrus" + "github.com/google/flatbuffers/go" + "github.com/manishrjain/dgraph/store" + "github.com/manishrjain/dgraph/task" + "github.com/manishrjain/dgraph/x" ) func TestPush(t *testing.T) { @@ -53,3 +61,96 @@ func TestPush(t *testing.T) { t.Errorf("Expected len 0. Found: %v, values: %+v", h.Len(), h) } } + +func addTriple(t *testing.T, triple x.Triple, l *List) { + if err := l.AddMutation(triple, Set); err != nil { + t.Error(err) + } +} + +func TestProcessQuery(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + + pdir := NewStore(t) + defer os.RemoveAll(pdir) + ps := new(store.Store) + ps.Init(pdir) + + mdir := NewStore(t) + defer os.RemoveAll(mdir) + ms := new(store.Store) + ms.Init(mdir) + Init(ps, ms) + + triple := x.Triple{ + ValueId: 23, + Source: "author0", + Timestamp: time.Now(), + } + addTriple(t, triple, Get(Key(10, "friend"))) + addTriple(t, triple, Get(Key(11, "friend"))) + addTriple(t, triple, Get(Key(12, "friend"))) + + triple.ValueId = 25 + addTriple(t, triple, Get(Key(12, "friend"))) + + triple.ValueId = 26 + addTriple(t, triple, Get(Key(12, "friend"))) + + triple.ValueId = 31 + addTriple(t, triple, Get(Key(10, "friend"))) + addTriple(t, triple, Get(Key(12, "friend"))) + + triple.Value = "photon" + addTriple(t, triple, Get(Key(12, "friend"))) + + query := NewQuery("friend", []uint64{10, 11, 12}) + result, err := ProcessQuery(query) + if err != nil { + t.Error(err) + } + + ro := flatbuffers.GetUOffsetT(result) + r := new(task.Result) + r.Init(result, ro) + + if r.UidsLength() != 4 { + t.Errorf("Expected 4. Got uids length: %v", r.UidsLength()) + } + if r.Uids(0) != 23 { + t.Errorf("Expected 23. Got: %v", r.Uids(0)) + } + if r.Uids(1) != 25 { + t.Errorf("Expected 25. Got: %v", r.Uids(0)) + } + if r.Uids(2) != 26 { + t.Errorf("Expected 26. Got: %v", r.Uids(0)) + } + if r.Uids(3) != 31 { + t.Errorf("Expected 31. Got: %v", r.Uids(0)) + } + if r.ValuesLength() != 3 { + t.Errorf("Expected 3. Got values length: %v", r.ValuesLength()) + } + + var tval task.Value + if ok := r.Values(&tval, 0); !ok { + t.Errorf("Unable to retrieve value") + } + if tval.ValLength() != 1 || + tval.ValBytes()[0] != 0x00 { + t.Errorf("Invalid byte value at index 0") + } + + if ok := r.Values(&tval, 2); !ok { + t.Errorf("Unable to retrieve value") + } + + var v string + if err := ParseValue(&v, tval.ValBytes()); err != nil { + t.Error(err) + } + if v != "photon" { + t.Errorf("Expected photon. Got: %q", v) + } +}