diff --git a/rdf/parse.go b/rdf/parse.go index 789553d47b3e1b6fa6e9f1d0befe6a3017efea42..0a8c8f6a2651366e161032debe2e5c8e30b6af40 100644 --- a/rdf/parse.go +++ b/rdf/parse.go @@ -70,6 +70,45 @@ func (nq NQuad) ToEdge(instanceIdx, return result, nil } +func toUid(xid string, xidToUid map[string]uint64) (uid uint64, rerr error) { + id, present := xidToUid[xid] + if present { + return id, nil + } + + if !strings.HasPrefix(xid, "_uid_:") { + return 0, fmt.Errorf("Unable to find xid: %v", xid) + } + return strconv.ParseUint(xid[6:], 0, 64) +} + +func (nq NQuad) ToEdgeUsing( + xidToUid map[string]uint64) (result x.DirectedEdge, rerr error) { + uid, err := toUid(nq.Subject, xidToUid) + if err != nil { + return result, err + } + result.Entity = uid + + if len(nq.ObjectId) == 0 { + result.Value = nq.ObjectValue + } else { + uid, err = toUid(nq.ObjectId, xidToUid) + if err != nil { + return result, err + } + result.ValueId = uid + } + if len(nq.Language) > 0 { + result.Attribute = nq.Predicate + "." + nq.Language + } else { + result.Attribute = nq.Predicate + } + result.Source = nq.Label + result.Timestamp = time.Now() + return result, nil +} + func stripBracketsIfPresent(val string) string { if val[0] != '<' { return val diff --git a/server/main.go b/server/main.go index f71564f5ed86636ee51f05eea320a356d1ff0a58..2c857348e0e99f9fffd9ab054ef6c7899f193a42 100644 --- a/server/main.go +++ b/server/main.go @@ -64,16 +64,54 @@ func addCorsHeaders(w http.ResponseWriter) { func mutationHandler(mu *gql.Mutation) error { r := strings.NewReader(mu.Set) scanner := bufio.NewScanner(r) + var nquads []rdf.NQuad for scanner.Scan() { ln := strings.Trim(scanner.Text(), " \t") if len(ln) == 0 { continue } - _, err := rdf.Parse(ln) + nq, err := rdf.Parse(ln) if err != nil { glog.WithError(err).Error("While parsing RDF.") return err } + nquads = append(nquads, nq) + } + + xidToUid := make(map[string]uint64) + for _, nq := range nquads { + if !strings.HasPrefix("_uid_:", nq.Subject) { + xidToUid[nq.Subject] = 0 + } + if !strings.HasPrefix("_uid_:", nq.ObjectId) { + xidToUid[nq.ObjectId] = 0 + } + } + if err := worker.GetOrAssignUidsOverNetwork(&xidToUid); err != nil { + return err + } + + var edges []x.DirectedEdge + for _, nq := range nquads { + edge, err := nq.ToEdgeUsing(xidToUid) + if err != nil { + glog.WithField("nquad", nq).WithError(err). + Error("While converting to edge") + return err + } + edges = append(edges, edge) + } + + left, err := worker.MutateOverNetwork(edges) + if err != nil { + return err + } + if len(left) > 0 { + glog.WithField("left", len(left)).Error("Some edges couldn't be applied") + for _, e := range left { + glog.WithField("edge", e).Debug("Unable to apply mutation") + } + return fmt.Errorf("Unapplied mutations") } return nil } diff --git a/worker/assign.go b/worker/assign.go new file mode 100644 index 0000000000000000000000000000000000000000..52c4a721fd87be768a1ff26ff8481a4294571938 --- /dev/null +++ b/worker/assign.go @@ -0,0 +1,115 @@ +package worker + +import ( + "sync" + + "github.com/dgraph-io/dgraph/conn" + "github.com/dgraph-io/dgraph/task" + "github.com/dgraph-io/dgraph/uid" + "github.com/google/flatbuffers/go" +) + +func createXidListBuffer(xids map[string]uint64) []byte { + b := flatbuffers.NewBuilder(0) + var offsets []flatbuffers.UOffsetT + for xid := range xids { + uo := b.CreateString(xid) + offsets = append(offsets, uo) + } + + task.XidListStartXidsVector(b, len(offsets)) + for _, uo := range offsets { + b.PrependUOffsetT(uo) + } + ve := b.EndVector(len(offsets)) + + task.XidListStart(b) + task.XidListAddXids(b, ve) + lo := task.XidListEnd(b) + b.Finish(lo) + return b.Bytes[b.Head():] +} + +func getOrAssignUids( + xidList *task.XidList) (uidList []byte, rerr error) { + + wg := new(sync.WaitGroup) + uids := make([]uint64, xidList.XidsLength()) + che := make(chan error, xidList.XidsLength()) + for i := 0; i < xidList.XidsLength(); i++ { + wg.Add(1) + xid := string(xidList.Xids(i)) + + go func() { + defer wg.Done() + u, err := uid.GetOrAssign(xid, 0, 1) + if err != nil { + che <- err + return + } + uids[i] = u + }() + } + wg.Wait() + close(che) + for err := range che { + glog.WithError(err).Error("Encountered errors while getOrAssignUids") + return uidList, err + } + + b := flatbuffers.NewBuilder(0) + task.UidListStartUidsVector(b, xidList.XidsLength()) + for i := len(uids) - 1; i >= 0; i-- { + b.PrependUint64(uids[i]) + } + ve := b.EndVector(xidList.XidsLength()) + + task.UidListStart(b) + task.UidListAddUids(b, ve) + uend := task.UidListEnd(b) + b.Finish(uend) + return b.Bytes[b.Head():], nil +} + +func GetOrAssignUidsOverNetwork(xidToUid *map[string]uint64) (rerr error) { + query := new(conn.Query) + query.Data = createXidListBuffer(*xidToUid) + uo := flatbuffers.GetUOffsetT(query.Data) + xidList := new(task.XidList) + xidList.Init(query.Data, uo) + + reply := new(conn.Reply) + if instanceIdx == 0 { + uo := flatbuffers.GetUOffsetT(query.Data) + xidList := new(task.XidList) + xidList.Init(query.Data, uo) + + reply.Data, rerr = getOrAssignUids(xidList) + if rerr != nil { + return rerr + } + } else { + pool := pools[0] + if err := pool.Call("Worker.GetOrAssign", query, reply); err != nil { + glog.WithField("method", "GetOrAssign").WithError(err). + Error("While getting uids") + return err + } + } + + uidList := new(task.UidList) + uo = flatbuffers.GetUOffsetT(reply.Data) + uidList.Init(reply.Data, uo) + + if xidList.XidsLength() != uidList.UidsLength() { + glog.WithField("num_xids", xidList.XidsLength()). + WithField("num_uids", uidList.UidsLength()). + Fatal("Num xids don't match num uids") + } + for i := 0; i < xidList.XidsLength(); i++ { + xid := string(xidList.Xids(i)) + uid := uidList.Uids(i) + (*xidToUid)[xid] = uid + } + return nil +} diff --git a/worker/assign_test.go b/worker/assign_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3905b6398d0601b66c66f3017cbcfe67b21f175e --- /dev/null +++ b/worker/assign_test.go @@ -0,0 +1,36 @@ +package worker + +import ( + "testing" + + "github.com/dgraph-io/dgraph/task" + "github.com/google/flatbuffers/go" +) + +func TestXidListBuffer(t *testing.T) { + xids := map[string]uint64{ + "b.0453": 0, + "d.z1sz": 0, + "e.abcd": 0, + } + + buf := createXidListBuffer(xids) + + uo := flatbuffers.GetUOffsetT(buf) + xl := new(task.XidList) + xl.Init(buf, uo) + + if xl.XidsLength() != len(xids) { + t.Errorf("Expected: %v. Got: %v", len(xids), xl.XidsLength()) + } + for i := 0; i < xl.XidsLength(); i++ { + xid := string(xl.Xids(i)) + t.Logf("Found: %v", xid) + xids[xid] = 7 + } + for xid, val := range xids { + if val != 7 { + t.Errorf("Expected xid: %v to be part of the buffer.", xid) + } + } +} diff --git a/worker/worker.go b/worker/worker.go index c09afe8abc263ca7a9dadd48570cb63dffffeb55..3fd932e8b4847664456b14a843f9f4376fe97c0b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -91,6 +91,22 @@ func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { return nil } +func (w *Worker) GetOrAssign(query *conn.Query, + reply *conn.Reply) (rerr error) { + + uo := flatbuffers.GetUOffsetT(query.Data) + xids := new(task.XidList) + xids.Init(query.Data, uo) + + if instanceIdx != 0 { + glog.WithField("instanceIdx", instanceIdx). + WithField("GetOrAssign", true). + Fatal("We shouldn't be receiving this request.") + } + reply.Data, rerr = getOrAssignUids(xids) + return +} + func (w *Worker) Mutate(query *conn.Query, reply *conn.Reply) (rerr error) { m := new(Mutations) if err := m.Decode(query.Data); err != nil {