diff --git a/conn/client.go b/conn/client.go new file mode 100644 index 0000000000000000000000000000000000000000..2b668dc05c65cd02f80f01529581d81369589696 --- /dev/null +++ b/conn/client.go @@ -0,0 +1,57 @@ +package conn + +import ( + "errors" + "fmt" + "io" + "log" + "net/rpc" +) + +type ClientCodec struct { + Rwc io.ReadWriteCloser + payloadLen int32 +} + +func (c *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error { + if body == nil { + return fmt.Errorf("Nil request body from client.") + } + + query := body.(*Query) + if err := writeHeader(c.Rwc, r.Seq, r.ServiceMethod, query.Data); err != nil { + return err + } + n, err := c.Rwc.Write(query.Data) + if n != len(query.Data) { + return errors.New("Unable to write payload.") + } + return err +} + +func (c *ClientCodec) ReadResponseHeader(r *rpc.Response) error { + if len(r.Error) > 0 { + log.Fatal("client got response error: " + r.Error) + } + if err := parseHeader(c.Rwc, &r.Seq, + &r.ServiceMethod, &c.payloadLen); err != nil { + return err + } + return nil +} + +func (c *ClientCodec) ReadResponseBody(body interface{}) error { + buf := make([]byte, c.payloadLen) + n, err := c.Rwc.Read(buf) + if n != int(c.payloadLen) { + return fmt.Errorf("ClientCodec expected: %d. Got: %d\n", c.payloadLen, n) + } + + reply := body.(*Reply) + reply.Data = buf + return err +} + +func (c *ClientCodec) Close() error { + return c.Rwc.Close() +} diff --git a/conn/codec.go b/conn/codec.go new file mode 100644 index 0000000000000000000000000000000000000000..e40864a02892aca7e242b0fca284f196a5975850 --- /dev/null +++ b/conn/codec.go @@ -0,0 +1,60 @@ +package conn + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/dgraph-io/dgraph/x" +) + +type Query struct { + Data []byte +} + +type Reply struct { + Data []byte +} + +func writeHeader(rwc io.ReadWriteCloser, seq uint64, + method string, data []byte) error { + + var bh bytes.Buffer + var rerr error + + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, seq)) + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method)))) + x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data)))) + _, err := bh.Write([]byte(method)) + x.SetError(&rerr, err) + if rerr != nil { + return rerr + } + _, err = rwc.Write(bh.Bytes()) + return err +} + +func parseHeader(rwc io.ReadWriteCloser, seq *uint64, + method *string, plen *int32) error { + + var err error + var sz int32 + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, seq)) + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, &sz)) + x.SetError(&err, binary.Read(rwc, binary.LittleEndian, plen)) + if err != nil { + return err + } + + buf := make([]byte, sz) + n, err := rwc.Read(buf) + if err != nil { + return err + } + if n != int(sz) { + return fmt.Errorf("Expected: %v. Got: %v\n", sz, n) + } + *method = string(buf) + return nil +} diff --git a/conn/pool.go b/conn/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..7168c9cf3eb1f96dc89feea11a7c1c96d3cdbc3b --- /dev/null +++ b/conn/pool.go @@ -0,0 +1,64 @@ +package conn + +import ( + "net" + "net/rpc" + + "github.com/dgraph-io/dgraph/x" +) + +var glog = x.Log("conn") + +type Pool struct { + clients chan *rpc.Client + addr string +} + +func NewPool(addr string, maxCap int) *Pool { + p := new(Pool) + p.addr = addr + p.clients = make(chan *rpc.Client, maxCap) + client, err := p.dialNew() + if err != nil { + glog.Fatal(err) + return nil + } + p.clients <- client + return p +} + +func (p *Pool) dialNew() (*rpc.Client, error) { + nconn, err := net.Dial("tcp", p.addr) + if err != nil { + return nil, err + } + cc := &ClientCodec{ + Rwc: nconn, + } + return rpc.NewClientWithCodec(cc), nil +} + +func (p *Pool) Get() (*rpc.Client, error) { + select { + case client := <-p.clients: + return client, nil + default: + return p.dialNew() + } +} + +func (p *Pool) Put(client *rpc.Client) error { + select { + case p.clients <- client: + return nil + default: + return client.Close() + } +} + +func (p *Pool) Close() error { + // We're not doing a clean exit here. A clean exit here would require + // mutex locks around conns; which seems unnecessary just to shut down + // the server. + return nil +} diff --git a/conn/server.go b/conn/server.go new file mode 100644 index 0000000000000000000000000000000000000000..986783a5f49f396496fd4cb4569ecb5f967326ae --- /dev/null +++ b/conn/server.go @@ -0,0 +1,61 @@ +package conn + +import ( + "errors" + "io" + "log" + "net/rpc" +) + +type ServerCodec struct { + Rwc io.ReadWriteCloser + payloadLen int32 +} + +func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error { + return parseHeader(c.Rwc, &r.Seq, &r.ServiceMethod, &c.payloadLen) +} + +func (c *ServerCodec) ReadRequestBody(data interface{}) error { + b := make([]byte, c.payloadLen) + n, err := c.Rwc.Read(b) + if err != nil { + log.Fatal("server", err) + } + if n != int(c.payloadLen) { + return errors.New("ServerCodec unable to read request.") + } + + if data == nil { + // If data is nil, discard this request. + return nil + } + query := data.(*Query) + query.Data = b + return nil +} + +func (c *ServerCodec) WriteResponse(resp *rpc.Response, data interface{}) error { + if len(resp.Error) > 0 { + log.Fatal("Response has error: " + resp.Error) + } + if data == nil { + log.Fatal("Worker write response data is nil") + } + reply, ok := data.(*Reply) + if !ok { + log.Fatal("Unable to convert to reply") + } + + if err := writeHeader(c.Rwc, resp.Seq, + resp.ServiceMethod, reply.Data); err != nil { + return err + } + + _, err := c.Rwc.Write(reply.Data) + return err +} + +func (c *ServerCodec) Close() error { + return c.Rwc.Close() +} diff --git a/server/main.go b/server/main.go index e06c931c4e252f9115ef04b7382df167773bfc04..86e2d765295002003af3964bae276f2dd2ee92c6 100644 --- a/server/main.go +++ b/server/main.go @@ -141,6 +141,7 @@ func main() { posting.Init(clog) worker.Init(ps) + worker.Connect() uid.Init(ps) http.HandleFunc("/query", queryHandler(ps)) diff --git a/task.fbs b/task.fbs index 39e877cff06b507f36cf5185f73bb788384080b3..2e3acbb08719ae83954a5b6c0620f3a9aca6406b 100644 --- a/task.fbs +++ b/task.fbs @@ -9,6 +9,10 @@ table Value { val:[ubyte]; } +table XidList { + xids:[string]; +} + table UidList { uids:[ulong]; } diff --git a/task/XidList.go b/task/XidList.go new file mode 100644 index 0000000000000000000000000000000000000000..90798a8d87ee3f03f1d2a418534b862ccda54e37 --- /dev/null +++ b/task/XidList.go @@ -0,0 +1,38 @@ +// automatically generated, do not modify + +package task + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) +type XidList struct { + _tab flatbuffers.Table +} + +func (rcv *XidList) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *XidList) Xids(j int) []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.ByteVector(a + flatbuffers.UOffsetT(j * 4)) + } + return nil +} + +func (rcv *XidList) XidsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func XidListStart(builder *flatbuffers.Builder) { builder.StartObject(1) } +func XidListAddXids(builder *flatbuffers.Builder, xids flatbuffers.UOffsetT) { builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(xids), 0) } +func XidListStartXidsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(4, numElems, 4) +} +func XidListEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/worker/worker.go b/worker/worker.go index 5538e0fa41dd35188fdaa0e8546ca8d885530407..5966a632afe1286ea589ac28d7e7462a59525396 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -1,6 +1,13 @@ package worker import ( + "flag" + "io" + "net" + "net/rpc" + "strings" + + "github.com/dgraph-io/dgraph/conn" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/store" "github.com/dgraph-io/dgraph/task" @@ -8,12 +15,53 @@ import ( "github.com/google/flatbuffers/go" ) +var workers = flag.String("workers", "", + "Comma separated list of IP addresses of workers") +var workerPort = flag.String("workerport", ":12345", + "Port used by worker for internal communication.") + +var glog = x.Log("worker") var dataStore *store.Store +var pools []*conn.Pool func Init(ps *store.Store) { dataStore = ps } +func Connect() { + w := new(Worker) + if err := rpc.Register(w); err != nil { + glog.Fatal(err) + } + if err := runServer(*workerPort); err != nil { + glog.Fatal(err) + } + + addrs := strings.Split(*workers, ",") + var pools []*conn.Pool + for _, addr := range addrs { + if len(addr) == 0 { + continue + } + pool := conn.NewPool(addr, 5) + client, err := pool.Get() + if err != nil { + glog.Fatal(err) + } + query := new(conn.Query) + query.Data = []byte("hello") + reply := new(conn.Reply) + if err = client.Call("Worker.Hello", query, reply); err != nil { + glog.WithField("call", "Worker.Hello").Fatal(err) + } + glog.WithField("reply", string(reply.Data)).WithField("addr", addr). + Info("Got reply from server") + pools = append(pools, pool) + } + + glog.Info("Server started. Clients connected.") +} + func ProcessTask(query []byte) (result []byte, rerr error) { uo := flatbuffers.GetUOffsetT(query) q := new(task.Query) @@ -78,3 +126,48 @@ func NewQuery(attr string, uids []uint64) []byte { b.Finish(qend) return b.Bytes[b.Head():] } + +type Worker struct { +} + +func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { + if string(query.Data) == "hello" { + reply.Data = []byte("Oh hello there!") + } else { + reply.Data = []byte("Hey stranger!") + } + return nil +} + +func serveRequests(irwc io.ReadWriteCloser) { + for { + sc := &conn.ServerCodec{ + Rwc: irwc, + } + rpc.ServeRequest(sc) + } +} + +func runServer(address string) error { + ln, err := net.Listen("tcp", address) + if err != nil { + glog.Fatalf("While running server: %v", err) + return err + } + glog.WithField("address", ln.Addr()).Info("Worker listening") + + go func() { + for { + cxn, err := ln.Accept() + if err != nil { + glog.Fatalf("listen(%q): %s\n", address, err) + return + } + glog.WithField("local", cxn.LocalAddr()). + WithField("remote", cxn.RemoteAddr()). + Debug("Worker accepted connection") + go serveRequests(cxn) + } + }() + return nil +} diff --git a/x/x.go b/x/x.go index a9e29e4860e3cd8f4404461b96c542decf7dbf31..2ae25556b65d41f68cb9659acbb58cdc54924931 100644 --- a/x/x.go +++ b/x/x.go @@ -39,6 +39,12 @@ type DirectedEdge struct { Timestamp time.Time } +func SetError(prev *error, n error) { + if prev == nil { + prev = &n + } +} + func Log(p string) *logrus.Entry { l := logrus.WithFields(logrus.Fields{ "package": p,