Skip to content
Snippets Groups Projects
worker.go 4.78 KiB
Newer Older
  • Learn to ignore specific revisions
  • package worker
    
    	"github.com/dgraph-io/dgraph/posting"
    
    Ashwin's avatar
    Ashwin committed
    	"github.com/dgraph-io/dgraph/store"
    
    	"github.com/dgraph-io/dgraph/task"
    
    	"github.com/google/flatbuffers/go"
    
    var workerPort = flag.String("workerport", ":12345",
    	"Port used by worker for internal communication.")
    
    var glog = x.Log("worker")
    
    var dataStore, xiduidStore *store.Store
    
    var addrs []string
    
    func Init(ps, xuStore *store.Store, workerList []string, idx uint64) {
    
    	dataStore = ps
    
    	xiduidStore = xuStore
    
    	addrs = workerList
    
    func Connect() {
    	w := new(Worker)
    	if err := rpc.Register(w); err != nil {
    		glog.Fatal(err)
    	}
    	if err := runServer(*workerPort); err != nil {
    		glog.Fatal(err)
    	}
    
    	for _, addr := range addrs {
    		if len(addr) == 0 {
    			continue
    		}
    		pool := conn.NewPool(addr, 5)
    		query := new(conn.Query)
    		query.Data = []byte("hello")
    		reply := new(conn.Reply)
    
    		if err := pool.Call("Worker.Hello", query, reply); err != nil {
    
    			glog.WithField("call", "Worker.Hello").Fatal(err)
    		}
    		glog.WithField("reply", string(reply.Data)).WithField("addr", addr).
    			Info("Got reply from server")
    		pools = append(pools, pool)
    	}
    
    	glog.Info("Server started. Clients connected.")
    }
    
    
    func ProcessTask(query []byte) (result []byte, rerr error) {
    
    	uo := flatbuffers.GetUOffsetT(query)
    	q := new(task.Query)
    	q.Init(query, uo)
    
    	b := flatbuffers.NewBuilder(0)
    
    	voffsets := make([]flatbuffers.UOffsetT, q.UidsLength())
    	uoffsets := make([]flatbuffers.UOffsetT, q.UidsLength())
    
    
    	attr := string(q.Attr())
    	for i := 0; i < q.UidsLength(); i++ {
    		uid := q.Uids(i)
    
    		key := posting.Key(uid, attr)
    
    		pl := posting.GetOrCreate(key, dataStore)
    
    
    		var valoffset flatbuffers.UOffsetT
    		if val, err := pl.Value(); err != nil {
    
    			valoffset = b.CreateByteVector(x.Nilbyte)
    
    		voffsets[i] = task.ValueEnd(b)
    
    		ulist := pl.GetUids()
    
    		uoffsets[i] = x.UidlistOffset(b, ulist)
    
    	}
    	task.ResultStartValuesVector(b, len(voffsets))
    	for i := len(voffsets) - 1; i >= 0; i-- {
    		b.PrependUOffsetT(voffsets[i])
    	}
    	valuesVent := b.EndVector(len(voffsets))
    
    
    	task.ResultStartUidmatrixVector(b, len(uoffsets))
    	for i := len(uoffsets) - 1; i >= 0; i-- {
    		b.PrependUOffsetT(uoffsets[i])
    
    	matrixVent := b.EndVector(len(uoffsets))
    
    
    	task.ResultStart(b)
    	task.ResultAddValues(b, valuesVent)
    
    	task.ResultAddUidmatrix(b, matrixVent)
    
    	rend := task.ResultEnd(b)
    	b.Finish(rend)
    	return b.Bytes[b.Head():], nil
    }
    
    
    Manish R Jain's avatar
    Manish R Jain committed
    func NewQuery(attr string, uids []uint64) []byte {
    	b := flatbuffers.NewBuilder(0)
    	task.QueryStartUidsVector(b, len(uids))
    	for i := len(uids) - 1; i >= 0; i-- {
    		b.PrependUint64(uids[i])
    	}
    	vend := b.EndVector(len(uids))
    
    	ao := b.CreateString(attr)
    	task.QueryStart(b)
    	task.QueryAddAttr(b, ao)
    	task.QueryAddUids(b, vend)
    	qend := task.QueryEnd(b)
    	b.Finish(qend)
    	return b.Bytes[b.Head():]
    }
    
    
    type Worker struct {
    }
    
    func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error {
    	if string(query.Data) == "hello" {
    		reply.Data = []byte("Oh hello there!")
    	} else {
    		reply.Data = []byte("Hey stranger!")
    	}
    	return nil
    }
    
    
    func (w *Worker) Mutate(query *conn.Query, reply *conn.Reply) (rerr error) {
    	m := new(Mutations)
    	if err := m.Decode(query.Data); err != nil {
    		return err
    	}
    
    	left := new(Mutations)
    	// For now, assume it's all only Set instructions.
    	for _, edge := range m.Set {
    		if farm.Fingerprint64(
    			[]byte(edge.Attribute))%uint64(len(addrs)) != instanceIdx {
    
    			glog.WithField("instanceIdx", instanceIdx).
    				WithField("attr", edge.Attribute).
    				Info("Predicate fingerprint doesn't match instanceIdx")
    			return fmt.Errorf("predicate fingerprint doesn't match this instance.")
    		}
    
    		key := posting.Key(edge.Entity, edge.Attribute)
    		plist := posting.GetOrCreate(key, dataStore)
    		if err := plist.AddMutation(edge, posting.Set); err != nil {
    			left.Set = append(left.Set, edge)
    			glog.WithError(err).WithField("edge", edge).Error("While adding mutation.")
    			continue
    		}
    	}
    	reply.Data, rerr = left.Encode()
    	return
    }
    
    
    func serveRequests(irwc io.ReadWriteCloser) {
    	for {
    		sc := &conn.ServerCodec{
    			Rwc: irwc,
    		}
    		rpc.ServeRequest(sc)
    	}
    }
    
    func runServer(address string) error {
    	ln, err := net.Listen("tcp", address)
    	if err != nil {
    		glog.Fatalf("While running server: %v", err)
    		return err
    	}
    	glog.WithField("address", ln.Addr()).Info("Worker listening")
    
    	go func() {
    		for {
    			cxn, err := ln.Accept()
    			if err != nil {
    				glog.Fatalf("listen(%q): %s\n", address, err)
    				return
    			}
    			glog.WithField("local", cxn.LocalAddr()).
    				WithField("remote", cxn.RemoteAddr()).
    				Debug("Worker accepted connection")
    			go serveRequests(cxn)
    		}
    	}()
    	return nil
    }