From c4629b907702748694712637cbef1fb2c1f15d07 Mon Sep 17 00:00:00 2001 From: Manish R Jain <manish@dgraph.io> Date: Wed, 8 Jun 2016 21:27:38 +1000 Subject: [PATCH] Switch from Golang RPC with custom encoding to using GRPC with custom encoding. This helps us use the amazing context feature of grpc for RPC tracing and client to server enforced deadlines. Also upgrade the protobuf package so we can compile proto3. --- conn/client.go | 69 ------ conn/codec.go | 78 ------ conn/codec_test.go | 161 ------------ conn/pool.go | 110 --------- conn/server.go | 75 ------ .../github.com/golang/protobuf/proto/clone.go | 12 +- .../golang/protobuf/proto/decode.go | 7 +- .../golang/protobuf/proto/encode.go | 52 +++- .../github.com/golang/protobuf/proto/equal.go | 26 +- .../golang/protobuf/proto/extensions.go | 196 +++++++++++++-- .../github.com/golang/protobuf/proto/lib.go | 4 + .../golang/protobuf/proto/message_set.go | 43 +++- .../golang/protobuf/proto/pointer_reflect.go | 5 + .../golang/protobuf/proto/pointer_unsafe.go | 4 + .../golang/protobuf/proto/properties.go | 24 +- .../github.com/golang/protobuf/proto/text.go | 13 +- .../golang/protobuf/proto/text_parser.go | 66 ++--- vendor/vendor.json | 6 +- worker/README.md | 5 + worker/assign.go | 22 +- worker/conn.go | 71 ++++++ worker/mutation.go | 23 +- worker/payload.pb.go | 231 ++++++++++++++++++ worker/payload.proto | 14 ++ worker/task.go | 16 +- worker/worker.go | 150 +++++------- 26 files changed, 797 insertions(+), 686 deletions(-) delete mode 100644 conn/client.go delete mode 100644 conn/codec.go delete mode 100644 conn/codec_test.go delete mode 100644 conn/pool.go delete mode 100644 conn/server.go create mode 100644 worker/README.md create mode 100644 worker/conn.go create mode 100644 worker/payload.pb.go create mode 100644 worker/payload.proto diff --git a/conn/client.go b/conn/client.go deleted file mode 100644 index 47bee518..00000000 --- a/conn/client.go +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "errors" - "fmt" - "io" - "log" - "net/rpc" -) - -type ClientCodec struct { - Rwc io.ReadWriteCloser - payloadLen int32 -} - -func (c *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error { - if body == nil { - return fmt.Errorf("Nil request body from client.") - } - - query := body.(*Query) - if err := writeHeader(c.Rwc, r.Seq, r.ServiceMethod, query.Data); err != nil { - return err - } - n, err := c.Rwc.Write(query.Data) - if n != len(query.Data) { - return errors.New("Unable to write payload.") - } - return err -} - -func (c *ClientCodec) ReadResponseHeader(r *rpc.Response) error { - if len(r.Error) > 0 { - log.Fatal("client got response error: " + r.Error) - } - if err := parseHeader(c.Rwc, &r.Seq, - &r.ServiceMethod, &c.payloadLen); err != nil { - return err - } - return nil -} - -func (c *ClientCodec) ReadResponseBody(body interface{}) error { - buf := make([]byte, c.payloadLen) - _, err := io.ReadFull(c.Rwc, buf) - reply := body.(*Reply) - reply.Data = buf - return err -} - -func (c *ClientCodec) Close() error { - return c.Rwc.Close() -} diff --git a/conn/codec.go b/conn/codec.go deleted file mode 100644 index ec921739..00000000 --- a/conn/codec.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - - "github.com/dgraph-io/dgraph/x" -) - -type Query struct { - Data []byte -} - -type Reply struct { - Data []byte - // TODO(manishrjain): Add an error here. - // Error string -} - -func writeHeader(rwc io.ReadWriteCloser, seq uint64, - method string, data []byte) error { - - var bh bytes.Buffer - var rerr error - - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, seq)) - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method)))) - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data)))) - _, err := bh.Write([]byte(method)) - x.SetError(&rerr, err) - if rerr != nil { - return rerr - } - _, err = rwc.Write(bh.Bytes()) - return err -} - -func parseHeader(rwc io.ReadWriteCloser, seq *uint64, - method *string, plen *int32) error { - - var err error - var sz int32 - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, seq)) - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, &sz)) - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, plen)) - if err != nil { - return err - } - - buf := make([]byte, sz) - n, err := rwc.Read(buf) - if err != nil { - return err - } - if n != int(sz) { - return fmt.Errorf("Expected: %v. Got: %v\n", sz, n) - } - *method = string(buf) - return nil -} diff --git a/conn/codec_test.go b/conn/codec_test.go deleted file mode 100644 index d8ff4eba..00000000 --- a/conn/codec_test.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "bytes" - "net/rpc" - "testing" -) - -type buf struct { - data chan byte -} - -func newBuf() *buf { - b := new(buf) - b.data = make(chan byte, 10000) - return b -} - -func (b *buf) Read(p []byte) (n int, err error) { - for i := 0; i < len(p); i++ { - p[i] = <-b.data - } - return len(p), nil -} - -func (b *buf) Write(p []byte) (n int, err error) { - for i := 0; i < len(p); i++ { - b.data <- p[i] - } - return len(p), nil -} - -func (b *buf) Close() error { - close(b.data) - return nil -} - -func TestWriteAndParseHeader(t *testing.T) { - b := newBuf() - data := []byte("oh hey") - if err := writeHeader(b, 11, "testing.T", data); err != nil { - t.Error(err) - t.Fail() - } - var seq uint64 - var method string - var plen int32 - if err := parseHeader(b, &seq, &method, &plen); err != nil { - t.Error(err) - t.Fail() - } - if seq != 11 { - t.Errorf("Sequence number. Expected 11. Got: %v", seq) - } - if method != "testing.T" { - t.Errorf("Method name. Expected: testing.T. Got: %v", method) - } - if plen != int32(len(data)) { - t.Errorf("Payload length. Expected: %v. Got: %v", len(data), plen) - } -} - -func TestClientToServer(t *testing.T) { - b := newBuf() - cc := &ClientCodec{ - Rwc: b, - } - sc := &ServerCodec{ - Rwc: b, - } - - r := &rpc.Request{ - ServiceMethod: "Test.ClientServer", - Seq: 11, - } - - query := new(Query) - query.Data = []byte("iamaquery") - if err := cc.WriteRequest(r, query); err != nil { - t.Error(err) - } - - sr := new(rpc.Request) - if err := sc.ReadRequestHeader(sr); err != nil { - t.Error(err) - } - if sr.Seq != r.Seq { - t.Errorf("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) - } - if sr.ServiceMethod != r.ServiceMethod { - t.Errorf("ServiceMethod. Expected: %v. Got: %v", - r.ServiceMethod, sr.ServiceMethod) - } - - squery := new(Query) - if err := sc.ReadRequestBody(squery); err != nil { - t.Error(err) - } - if !bytes.Equal(squery.Data, query.Data) { - t.Errorf("Queries don't match. Expected: %v Got: %v", - string(query.Data), string(squery.Data)) - } -} - -func TestServerToClient(t *testing.T) { - b := newBuf() - cc := &ClientCodec{ - Rwc: b, - } - sc := &ServerCodec{ - Rwc: b, - } - - r := &rpc.Response{ - ServiceMethod: "Test.ClientServer", - Seq: 11, - } - - reply := new(Reply) - reply.Data = []byte("iamareply") - if err := sc.WriteResponse(r, reply); err != nil { - t.Error(err) - } - - cr := new(rpc.Response) - if err := cc.ReadResponseHeader(cr); err != nil { - t.Error(err) - } - if cr.Seq != r.Seq { - t.Errorf("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) - } - if cr.ServiceMethod != r.ServiceMethod { - t.Errorf("ServiceMethod. Expected: %v. Got: %v", - r.ServiceMethod, cr.ServiceMethod) - } - - creply := new(Reply) - if err := cc.ReadResponseBody(creply); err != nil { - t.Error(err) - } - if !bytes.Equal(creply.Data, reply.Data) { - t.Errorf("Replies don't match. Expected: %v Got: %v", - string(reply.Data), string(creply.Data)) - } -} diff --git a/conn/pool.go b/conn/pool.go deleted file mode 100644 index 76a107e5..00000000 --- a/conn/pool.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "net" - "net/rpc" - "strings" - "time" - - "github.com/dgraph-io/dgraph/x" -) - -var glog = x.Log("conn") - -type Pool struct { - clients chan *rpc.Client - Addr string -} - -func NewPool(addr string, maxCap int) *Pool { - p := new(Pool) - p.Addr = addr - p.clients = make(chan *rpc.Client, maxCap) - client, err := p.dialNew() - if err != nil { - glog.Fatal(err) - return nil - } - p.clients <- client - return p -} - -func (p *Pool) dialNew() (*rpc.Client, error) { - d := &net.Dialer{ - Timeout: 3 * time.Minute, - } - var nconn net.Conn - var err error - // This loop will retry for 10 minutes before giving up. - for i := 0; i < 60; i++ { - nconn, err = d.Dial("tcp", p.Addr) - if err == nil { - break - } - if !strings.Contains(err.Error(), "refused") { - break - } - - glog.WithField("error", err).WithField("addr", p.Addr). - Info("Retrying connection...") - time.Sleep(10 * time.Second) - } - if err != nil { - return nil, err - } - cc := &ClientCodec{ - Rwc: nconn, - } - return rpc.NewClientWithCodec(cc), nil -} - -func (p *Pool) Call(serviceMethod string, args interface{}, - reply interface{}) error { - - client, err := p.get() - if err != nil { - return err - } - if err = client.Call(serviceMethod, args, reply); err != nil { - return err - } - - select { - case p.clients <- client: - return nil - default: - return client.Close() - } -} - -func (p *Pool) get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() - } -} - -func (p *Pool) Close() error { - // We're not doing a clean exit here. A clean exit here would require - // synchronization, which seems unnecessary for now. But, we should - // add one if required later. - return nil -} diff --git a/conn/server.go b/conn/server.go deleted file mode 100644 index 406ea9bc..00000000 --- a/conn/server.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "io" - "log" - "net/rpc" -) - -type ServerCodec struct { - Rwc io.ReadWriteCloser - payloadLen int32 -} - -func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error { - return parseHeader(c.Rwc, &r.Seq, &r.ServiceMethod, &c.payloadLen) -} - -func (c *ServerCodec) ReadRequestBody(data interface{}) error { - b := make([]byte, c.payloadLen) - _, err := io.ReadFull(c.Rwc, b) - if err != nil { - return err - } - - if data == nil { - // If data is nil, discard this request. - return nil - } - query := data.(*Query) - query.Data = b - return nil -} - -func (c *ServerCodec) WriteResponse(resp *rpc.Response, - data interface{}) error { - - if len(resp.Error) > 0 { - log.Fatal("Response has error: " + resp.Error) - } - if data == nil { - log.Fatal("Worker write response data is nil") - } - reply, ok := data.(*Reply) - if !ok { - log.Fatal("Unable to convert to reply") - } - - if err := writeHeader(c.Rwc, resp.Seq, - resp.ServiceMethod, reply.Data); err != nil { - return err - } - - _, err := c.Rwc.Write(reply.Data) - return err -} - -func (c *ServerCodec) Close() error { - return c.Rwc.Close() -} diff --git a/vendor/github.com/golang/protobuf/proto/clone.go b/vendor/github.com/golang/protobuf/proto/clone.go index e98ddec9..e392575b 100644 --- a/vendor/github.com/golang/protobuf/proto/clone.go +++ b/vendor/github.com/golang/protobuf/proto/clone.go @@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) { mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) } - if emIn, ok := in.Addr().Interface().(extendableProto); ok { - emOut := out.Addr().Interface().(extendableProto) - mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) + if emIn, ok := extendable(in.Addr().Interface()); ok { + emOut, _ := extendable(out.Addr().Interface()) + mIn, muIn := emIn.extensionsRead() + if mIn != nil { + mOut := emOut.extensionsWrite() + muIn.Lock() + mergeExtension(mOut, mIn) + muIn.Unlock() + } } uf := in.FieldByName("XXX_unrecognized") diff --git a/vendor/github.com/golang/protobuf/proto/decode.go b/vendor/github.com/golang/protobuf/proto/decode.go index f94b9f41..07288a25 100644 --- a/vendor/github.com/golang/protobuf/proto/decode.go +++ b/vendor/github.com/golang/protobuf/proto/decode.go @@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group if !ok { // Maybe it's an extension? if prop.extendable { - if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { + if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) { if err = o.skip(st, tag, wire); err == nil { - ext := e.ExtensionMap()[int32(tag)] // may be missing + extmap := e.extensionsWrite() + ext := extmap[int32(tag)] // may be missing ext.enc = append(ext.enc, o.buf[oi:o.index]...) - e.ExtensionMap()[int32(tag)] = ext + extmap[int32(tag)] = ext } continue } diff --git a/vendor/github.com/golang/protobuf/proto/encode.go b/vendor/github.com/golang/protobuf/proto/encode.go index eb7e0474..8c1b8fd1 100644 --- a/vendor/github.com/golang/protobuf/proto/encode.go +++ b/vendor/github.com/golang/protobuf/proto/encode.go @@ -70,6 +70,10 @@ var ( // ErrNil is the error returned if Marshal is called with nil. ErrNil = errors.New("proto: Marshal called with nil") + + // ErrTooLarge is the error returned if Marshal is called with a + // message that encodes to >2GB. + ErrTooLarge = errors.New("proto: message encodes to over 2 GB") ) // The fundamental encoders that put bytes on the wire. @@ -78,6 +82,10 @@ var ( const maxVarintBytes = 10 // maximum length of a varint +// maxMarshalSize is the largest allowed size of an encoded protobuf, +// since C++ and Java use signed int32s for the size. +const maxMarshalSize = 1<<31 - 1 + // EncodeVarint returns the varint encoding of x. // This is the format for the // int32, int64, uint32, uint64, bool, and enum @@ -277,6 +285,9 @@ func (p *Buffer) Marshal(pb Message) error { stats.Encode++ } + if len(p.buf) > maxMarshalSize { + return ErrTooLarge + } return err } @@ -1062,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) { // Encode an extension map. func (o *Buffer) enc_map(p *Properties, base structPointer) error { - v := *structPointer_ExtMap(base, p.field) - if err := encodeExtensionMap(v); err != nil { + exts := structPointer_ExtMap(base, p.field) + if err := encodeExtensionsMap(*exts); err != nil { return err } + + return o.enc_map_body(*exts) +} + +func (o *Buffer) enc_exts(p *Properties, base structPointer) error { + exts := structPointer_Extensions(base, p.field) + if err := encodeExtensions(exts); err != nil { + return err + } + v, _ := exts.extensionsRead() + + return o.enc_map_body(v) +} + +func (o *Buffer) enc_map_body(v map[int32]Extension) error { // Fast-path for common cases: zero or one extensions. if len(v) <= 1 { for _, e := range v { @@ -1088,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error { } func size_map(p *Properties, base structPointer) int { - v := *structPointer_ExtMap(base, p.field) - return sizeExtensionMap(v) + v := structPointer_ExtMap(base, p.field) + return extensionsMapSize(*v) +} + +func size_exts(p *Properties, base structPointer) int { + v := structPointer_Extensions(base, p.field) + return extensionsSize(v) } // Encode a map field. @@ -1118,7 +1149,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { return err } - if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil { + if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil { return err } return nil @@ -1128,11 +1159,6 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { for _, key := range v.MapKeys() { val := v.MapIndex(key) - // The only illegal map entry values are nil message pointers. - if val.Kind() == reflect.Ptr && val.IsNil() { - return errors.New("proto: map has nil element") - } - keycopy.Set(key) valcopy.Set(val) @@ -1220,6 +1246,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { return err } } + if len(o.buf) > maxMarshalSize { + return ErrTooLarge + } } } @@ -1236,6 +1265,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { // Add unrecognized fields at the end. if prop.unrecField.IsValid() { v := *structPointer_Bytes(base, prop.unrecField) + if len(o.buf)+len(v) > maxMarshalSize { + return ErrTooLarge + } if len(v) > 0 { o.buf = append(o.buf, v...) } diff --git a/vendor/github.com/golang/protobuf/proto/equal.go b/vendor/github.com/golang/protobuf/proto/equal.go index f5db1def..8b16f951 100644 --- a/vendor/github.com/golang/protobuf/proto/equal.go +++ b/vendor/github.com/golang/protobuf/proto/equal.go @@ -121,9 +121,16 @@ func equalStruct(v1, v2 reflect.Value) bool { } } + if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_InternalExtensions") + if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) { + return false + } + } + if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { em2 := v2.FieldByName("XXX_extensions") - if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { return false } } @@ -184,6 +191,13 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool { } return true case reflect.Ptr: + // Maps may have nil values in them, so check for nil. + if v1.IsNil() && v2.IsNil() { + return true + } + if v1.IsNil() != v2.IsNil() { + return false + } return equalAny(v1.Elem(), v2.Elem(), prop) case reflect.Slice: if v1.Type().Elem().Kind() == reflect.Uint8 { @@ -223,8 +237,14 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool { } // base is the struct type that the extensions are based on. -// em1 and em2 are extension maps. -func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { +// x1 and x2 are InternalExtensions. +func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool { + em1, _ := x1.extensionsRead() + em2, _ := x2.extensionsRead() + return equalExtMap(base, em1, em2) +} + +func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool { if len(em1) != len(em2) { return false } diff --git a/vendor/github.com/golang/protobuf/proto/extensions.go b/vendor/github.com/golang/protobuf/proto/extensions.go index 054f4f1d..9f484f53 100644 --- a/vendor/github.com/golang/protobuf/proto/extensions.go +++ b/vendor/github.com/golang/protobuf/proto/extensions.go @@ -52,14 +52,99 @@ type ExtensionRange struct { Start, End int32 // both inclusive } -// extendableProto is an interface implemented by any protocol buffer that may be extended. +// extendableProto is an interface implemented by any protocol buffer generated by the current +// proto compiler that may be extended. type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + extensionsWrite() map[int32]Extension + extensionsRead() (map[int32]Extension, sync.Locker) +} + +// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous +// version of the proto compiler that may be extended. +type extendableProtoV1 interface { Message ExtensionRangeArray() []ExtensionRange ExtensionMap() map[int32]Extension } +// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. +type extensionAdapter struct { + extendableProtoV1 +} + +func (e extensionAdapter) extensionsWrite() map[int32]Extension { + return e.ExtensionMap() +} + +func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { + return e.ExtensionMap(), notLocker{} +} + +// notLocker is a sync.Locker whose Lock and Unlock methods are nops. +type notLocker struct{} + +func (n notLocker) Lock() {} +func (n notLocker) Unlock() {} + +// extendable returns the extendableProto interface for the given generated proto message. +// If the proto message has the old extension format, it returns a wrapper that implements +// the extendableProto interface. +func extendable(p interface{}) (extendableProto, bool) { + if ep, ok := p.(extendableProto); ok { + return ep, ok + } + if ep, ok := p.(extendableProtoV1); ok { + return extensionAdapter{ep}, ok + } + return nil, false +} + +// XXX_InternalExtensions is an internal representation of proto extensions. +// +// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, +// thus gaining the unexported 'extensions' method, which can be called only from the proto package. +// +// The methods of XXX_InternalExtensions are not concurrency safe in general, +// but calls to logically read-only methods such as has and get may be executed concurrently. +type XXX_InternalExtensions struct { + // The struct must be indirect so that if a user inadvertently copies a + // generated message and its embedded XXX_InternalExtensions, they + // avoid the mayhem of a copied mutex. + // + // The mutex serializes all logically read-only operations to p.extensionMap. + // It is up to the client to ensure that write operations to p.extensionMap are + // mutually exclusive with other accesses. + p *struct { + mu sync.Mutex + extensionMap map[int32]Extension + } +} + +// extensionsWrite returns the extension map, creating it on first use. +func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { + if e.p == nil { + e.p = new(struct { + mu sync.Mutex + extensionMap map[int32]Extension + }) + e.p.extensionMap = make(map[int32]Extension) + } + return e.p.extensionMap +} + +// extensionsRead returns the extensions map for read-only use. It may be nil. +// The caller must hold the returned mutex's lock when accessing Elements within the map. +func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { + if e.p == nil { + return nil, nil + } + return e.p.extensionMap, &e.p.mu +} + var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() +var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() // ExtensionDesc represents an extension specification. // Used in generated code from the protocol compiler. @@ -92,8 +177,13 @@ type Extension struct { } // SetRawExtension is for testing only. -func SetRawExtension(base extendableProto, id int32, b []byte) { - base.ExtensionMap()[id] = Extension{enc: b} +func SetRawExtension(base Message, id int32, b []byte) { + epb, ok := extendable(base) + if !ok { + return + } + extmap := epb.extensionsWrite() + extmap[id] = Extension{enc: b} } // isExtensionField returns true iff the given field number is in an extension range. @@ -108,8 +198,12 @@ func isExtensionField(pb extendableProto, field int32) bool { // checkExtensionTypes checks that the given extension is valid for pb. func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + var pbi interface{} = pb // Check the extended type. - if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { + if ea, ok := pbi.(extensionAdapter); ok { + pbi = ea.extendableProtoV1 + } + if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) } // Check the range. @@ -155,8 +249,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties { return prop } -// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. -func encodeExtensionMap(m map[int32]Extension) error { +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensions(e *XXX_InternalExtensions) error { + m, mu := e.extensionsRead() + if m == nil { + return nil // fast path + } + mu.Lock() + defer mu.Unlock() + return encodeExtensionsMap(m) +} + +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensionsMap(m map[int32]Extension) error { for k, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -184,7 +289,17 @@ func encodeExtensionMap(m map[int32]Extension) error { return nil } -func sizeExtensionMap(m map[int32]Extension) (n int) { +func extensionsSize(e *XXX_InternalExtensions) (n int) { + m, mu := e.extensionsRead() + if m == nil { + return 0 + } + mu.Lock() + defer mu.Unlock() + return extensionsMapSize(m) +} + +func extensionsMapSize(m map[int32]Extension) (n int) { for _, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -209,26 +324,51 @@ func sizeExtensionMap(m map[int32]Extension) (n int) { } // HasExtension returns whether the given extension is present in pb. -func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { +func HasExtension(pb Message, extension *ExtensionDesc) bool { // TODO: Check types, field numbers, etc.? - _, ok := pb.ExtensionMap()[extension.Field] + epb, ok := extendable(pb) + if !ok { + return false + } + extmap, mu := epb.extensionsRead() + if extmap == nil { + return false + } + mu.Lock() + _, ok = extmap[extension.Field] + mu.Unlock() return ok } // ClearExtension removes the given extension from pb. -func ClearExtension(pb extendableProto, extension *ExtensionDesc) { +func ClearExtension(pb Message, extension *ExtensionDesc) { + epb, ok := extendable(pb) + if !ok { + return + } // TODO: Check types, field numbers, etc.? - delete(pb.ExtensionMap(), extension.Field) + extmap := epb.extensionsWrite() + delete(extmap, extension.Field) } // GetExtension parses and returns the given extension of pb. // If the extension is not present and has no default value it returns ErrMissingExtension. -func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { - if err := checkExtensionTypes(pb, extension); err != nil { +func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { + epb, ok := extendable(pb) + if !ok { + return nil, errors.New("proto: not an extendable proto") + } + + if err := checkExtensionTypes(epb, extension); err != nil { return nil, err } - emap := pb.ExtensionMap() + emap, mu := epb.extensionsRead() + if emap == nil { + return defaultExtensionValue(extension) + } + mu.Lock() + defer mu.Unlock() e, ok := emap[extension.Field] if !ok { // defaultExtensionValue returns the default value or @@ -332,10 +472,9 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { // GetExtensions returns a slice of the extensions present in pb that are also listed in es. // The returned slice has the same length as es; missing extensions will appear as nil elements. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { - err = errors.New("proto: not an extendable proto") - return + return nil, errors.New("proto: not an extendable proto") } extensions = make([]interface{}, len(es)) for i, e := range es { @@ -351,8 +490,12 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e } // SetExtension sets the specified extension of pb to the specified value. -func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { - if err := checkExtensionTypes(pb, extension); err != nil { +func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { + epb, ok := extendable(pb) + if !ok { + return errors.New("proto: not an extendable proto") + } + if err := checkExtensionTypes(epb, extension); err != nil { return err } typ := reflect.TypeOf(extension.ExtensionType) @@ -368,10 +511,23 @@ func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{ return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) } - pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + extmap := epb.extensionsWrite() + extmap[extension.Field] = Extension{desc: extension, value: value} return nil } +// ClearAllExtensions clears all extensions from pb. +func ClearAllExtensions(pb Message) { + epb, ok := extendable(pb) + if !ok { + return + } + m := epb.extensionsWrite() + for k := range m { + delete(m, k) + } +} + // A global registry of extensions. // The generated code will register the generated descriptors by calling RegisterExtension. diff --git a/vendor/github.com/golang/protobuf/proto/lib.go b/vendor/github.com/golang/protobuf/proto/lib.go index 0de8f8df..170b8e87 100644 --- a/vendor/github.com/golang/protobuf/proto/lib.go +++ b/vendor/github.com/golang/protobuf/proto/lib.go @@ -889,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool { return false } +// ProtoPackageIsVersion2 is referenced from generated protocol buffer files +// to assert that that code is compatible with this version of the proto package. +const ProtoPackageIsVersion2 = true + // ProtoPackageIsVersion1 is referenced from generated protocol buffer files // to assert that that code is compatible with this version of the proto package. const ProtoPackageIsVersion1 = true diff --git a/vendor/github.com/golang/protobuf/proto/message_set.go b/vendor/github.com/golang/protobuf/proto/message_set.go index e25e01e6..fd982dec 100644 --- a/vendor/github.com/golang/protobuf/proto/message_set.go +++ b/vendor/github.com/golang/protobuf/proto/message_set.go @@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte { // MarshalMessageSet encodes the extension map represented by m in the message set wire format. // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { - if err := encodeExtensionMap(m); err != nil { - return nil, err +func MarshalMessageSet(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + if err := encodeExtensions(exts); err != nil { + return nil, err + } + m, _ = exts.extensionsRead() + case map[int32]Extension: + if err := encodeExtensionsMap(exts); err != nil { + return nil, err + } + m = exts + default: + return nil, errors.New("proto: not an extension map") } // Sort extension IDs to provide a deterministic encoding. @@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSet(buf []byte, exts interface{}) error { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m = exts.extensionsWrite() + case map[int32]Extension: + m = exts + default: + return errors.New("proto: not an extension map") + } + ms := new(messageSet) if err := Unmarshal(buf, ms); err != nil { return err @@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { // MarshalMessageSetJSON encodes the extension map represented by m in JSON format. // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { +func MarshalMessageSetJSON(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m, _ = exts.extensionsRead() + case map[int32]Extension: + m = exts + default: + return nil, errors.New("proto: not an extension map") + } var b bytes.Buffer b.WriteByte('{') @@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error { // Common-case fast path. if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { return nil diff --git a/vendor/github.com/golang/protobuf/proto/pointer_reflect.go b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go index 98991417..fb512e2e 100644 --- a/vendor/github.com/golang/protobuf/proto/pointer_reflect.go +++ b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go @@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { return structPointer_ifield(p, f).(*[]string) } +// Extensions returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return structPointer_ifield(p, f).(*XXX_InternalExtensions) +} + // ExtMap returns the address of an extension map field in the struct. func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return structPointer_ifield(p, f).(*map[int32]Extension) diff --git a/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go index ceece772..6b5567d4 100644 --- a/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go +++ b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go @@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { } // ExtMap returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) } diff --git a/vendor/github.com/golang/protobuf/proto/properties.go b/vendor/github.com/golang/protobuf/proto/properties.go index 880eb22d..dd29683c 100644 --- a/vendor/github.com/golang/protobuf/proto/properties.go +++ b/vendor/github.com/golang/protobuf/proto/properties.go @@ -473,17 +473,13 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock p.dec = (*Buffer).dec_slice_int64 p.packedDec = (*Buffer).dec_slice_packed_int64 case reflect.Uint8: - p.enc = (*Buffer).enc_slice_byte p.dec = (*Buffer).dec_slice_byte - p.size = size_slice_byte - // This is a []byte, which is either a bytes field, - // or the value of a map field. In the latter case, - // we always encode an empty []byte, so we should not - // use the proto3 enc/size funcs. - // f == nil iff this is the key/value of a map field. - if p.proto3 && f != nil { + if p.proto3 { p.enc = (*Buffer).enc_proto3_slice_byte p.size = size_proto3_slice_byte + } else { + p.enc = (*Buffer).enc_slice_byte + p.size = size_slice_byte } case reflect.Float32, reflect.Float64: switch t2.Bits() { @@ -682,7 +678,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { propertiesMap[t] = prop // build properties - prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) || + reflect.PtrTo(t).Implements(extendableProtoV1Type) prop.unrecField = invalidField prop.Prop = make([]*Properties, t.NumField()) prop.order = make([]int, t.NumField()) @@ -693,12 +690,15 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { name := f.Name p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) - if f.Name == "XXX_extensions" { // special case + if f.Name == "XXX_InternalExtensions" { // special case + p.enc = (*Buffer).enc_exts + p.dec = nil // not needed + p.size = size_exts + } else if f.Name == "XXX_extensions" { // special case p.enc = (*Buffer).enc_map p.dec = nil // not needed p.size = size_map - } - if f.Name == "XXX_unrecognized" { // special case + } else if f.Name == "XXX_unrecognized" { // special case prop.unrecField = toField(&f) } oneof := f.Tag.Get("protobuf_oneof") // special case diff --git a/vendor/github.com/golang/protobuf/proto/text.go b/vendor/github.com/golang/protobuf/proto/text.go index 37c95357..8214ce32 100644 --- a/vendor/github.com/golang/protobuf/proto/text.go +++ b/vendor/github.com/golang/protobuf/proto/text.go @@ -455,7 +455,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { // Extensions (the XXX_extensions field). pv := sv.Addr() - if pv.Type().Implements(extendableProtoType) { + if _, ok := extendable(pv.Interface()); ok { if err := tm.writeExtensions(w, pv); err != nil { return err } @@ -513,7 +513,7 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert switch v.Kind() { case reflect.Slice: // Should only be a []byte; repeated fields are handled in writeStruct. - if err := writeString(w, string(v.Interface().([]byte))); err != nil { + if err := writeString(w, string(v.Bytes())); err != nil { return err } case reflect.String: @@ -689,17 +689,22 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // pv is assumed to be a pointer to a protocol message struct that is extendable. func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error { emap := extensionMaps[pv.Type().Elem()] - ep := pv.Interface().(extendableProto) + ep, _ := extendable(pv.Interface()) // Order the extensions by ID. // This isn't strictly necessary, but it will give us // canonical output, which will also make testing easier. - m := ep.ExtensionMap() + m, mu := ep.extensionsRead() + if m == nil { + return nil + } + mu.Lock() ids := make([]int32, 0, len(m)) for id := range m { ids = append(ids, id) } sort.Sort(int32Slice(ids)) + mu.Unlock() for _, extNum := range ids { ext := m[extNum] diff --git a/vendor/github.com/golang/protobuf/proto/text_parser.go b/vendor/github.com/golang/protobuf/proto/text_parser.go index b5fba5b2..0b8c59f7 100644 --- a/vendor/github.com/golang/protobuf/proto/text_parser.go +++ b/vendor/github.com/golang/protobuf/proto/text_parser.go @@ -550,7 +550,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { } reqFieldErr = err } - ep := sv.Addr().Interface().(extendableProto) + ep := sv.Addr().Interface().(Message) if !rep { SetExtension(ep, desc, ext.Interface()) } else { @@ -602,8 +602,9 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { // The map entry should be this sequence of tokens: // < key : KEY value : VALUE > - // Technically the "key" and "value" could come in any order, - // but in practice they won't. + // However, implementations may omit key or value, and technically + // we should support them in any order. See b/28924776 for a time + // this went wrong. tok := p.next() var terminator string @@ -615,32 +616,39 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { default: return p.errorf("expected '{' or '<', found %q", tok.value) } - if err := p.consumeToken("key"); err != nil { - return err - } - if err := p.consumeToken(":"); err != nil { - return err - } - if err := p.readAny(key, props.mkeyprop); err != nil { - return err - } - if err := p.consumeOptionalSeparator(); err != nil { - return err - } - if err := p.consumeToken("value"); err != nil { - return err - } - if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { - return err - } - if err := p.readAny(val, props.mvalprop); err != nil { - return err - } - if err := p.consumeOptionalSeparator(); err != nil { - return err - } - if err := p.consumeToken(terminator); err != nil { - return err + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + switch tok.value { + case "key": + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(key, props.mkeyprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + case "value": + if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { + return err + } + if err := p.readAny(val, props.mvalprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + default: + p.back() + return p.errorf(`expected "key", "value", or %q, found %q`, terminator, tok.value) + } } dst.SetMapIndex(key, val) diff --git a/vendor/vendor.json b/vendor/vendor.json index b7373bef..2ec827ea 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -15,10 +15,10 @@ "revisionTime": "2015-09-09T17:09:13Z" }, { - "checksumSHA1": "FczzogSoZcKU3h21tCUyHzMsnBY=", + "checksumSHA1": "Yz1gT/uux3UsjbGDG/OP7gZREY0=", "path": "github.com/golang/protobuf/proto", - "revision": "7cc19b78d562895b13596ddce7aafb59dd789318", - "revisionTime": "2016-04-25T21:53:00Z" + "revision": "3b06fc7a4cad73efce5fe6217ab6c33e7231ab4a", + "revisionTime": "2016-06-01T15:17:05Z" }, { "checksumSHA1": "bAXE13JoHB8GZ5svBJpzCWiuZBQ=", diff --git a/worker/README.md b/worker/README.md new file mode 100644 index 00000000..01550d83 --- /dev/null +++ b/worker/README.md @@ -0,0 +1,5 @@ +To update the protocol buffer definitions, run this from one directory above: + +``` +protoc -I worker worker/payload.proto --go_out=plugins=grpc:worker +``` diff --git a/worker/assign.go b/worker/assign.go index 99e24525..45ed5b5f 100644 --- a/worker/assign.go +++ b/worker/assign.go @@ -20,7 +20,8 @@ import ( "fmt" "sync" - "github.com/dgraph-io/dgraph/conn" + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/uid" "github.com/google/flatbuffers/go" @@ -93,13 +94,13 @@ func getOrAssignUids( } func GetOrAssignUidsOverNetwork(xidToUid *map[string]uint64) (rerr error) { - query := new(conn.Query) + query := new(Payload) query.Data = createXidListBuffer(*xidToUid) uo := flatbuffers.GetUOffsetT(query.Data) xidList := new(task.XidList) xidList.Init(query.Data, uo) - reply := new(conn.Reply) + reply := new(Payload) if instanceIdx == 0 { uo := flatbuffers.GetUOffsetT(query.Data) xidList := new(task.XidList) @@ -109,13 +110,22 @@ func GetOrAssignUidsOverNetwork(xidToUid *map[string]uint64) (rerr error) { if rerr != nil { return rerr } + } else { pool := pools[0] - if err := pool.Call("Worker.GetOrAssign", query, reply); err != nil { - glog.WithField("method", "GetOrAssign").WithError(err). - Error("While getting uids") + conn, err := pool.Get() + if err != nil { + glog.WithError(err).Error("Unable to retrieve connection.") return err } + c := NewWorkerClient(conn) + + reply, rerr = c.GetOrAssign(context.Background(), query) + if rerr != nil { + glog.WithField("method", "GetOrAssign").WithError(rerr). + Error("While getting uids") + return rerr + } } uidList := new(task.UidList) diff --git a/worker/conn.go b/worker/conn.go new file mode 100644 index 00000000..482fecb9 --- /dev/null +++ b/worker/conn.go @@ -0,0 +1,71 @@ +package worker + +import ( + "log" + + "google.golang.org/grpc" +) + +type PayloadCodec struct{} + +func (cb *PayloadCodec) Marshal(v interface{}) ([]byte, error) { + p, ok := v.(*Payload) + if !ok { + log.Fatalf("Invalid type of struct: %+v", v) + } + return p.Data, nil +} + +func (cb *PayloadCodec) Unmarshal(data []byte, v interface{}) error { + p, ok := v.(*Payload) + if !ok { + log.Fatalf("Invalid type of struct: %+v", v) + } + p.Data = data + return nil +} + +func (cb *PayloadCodec) String() string { + return "worker.PayloadCodec" +} + +type Pool struct { + conns chan *grpc.ClientConn + Addr string +} + +func NewPool(addr string, maxCap int) *Pool { + p := new(Pool) + p.Addr = addr + p.conns = make(chan *grpc.ClientConn, maxCap) + conn, err := p.dialNew() + if err != nil { + glog.Fatal(err) + return nil + } + p.conns <- conn + return p +} + +func (p *Pool) dialNew() (*grpc.ClientConn, error) { + return grpc.Dial(p.Addr, grpc.WithInsecure(), grpc.WithInsecure(), + grpc.WithCodec(&PayloadCodec{})) +} + +func (p *Pool) Get() (*grpc.ClientConn, error) { + select { + case conn := <-p.conns: + return conn, nil + default: + return p.dialNew() + } +} + +func (p *Pool) Put(conn *grpc.ClientConn) error { + select { + case p.conns <- conn: + return nil + default: + return conn.Close() + } +} diff --git a/worker/mutation.go b/worker/mutation.go index 36a6666d..78c2a6a5 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -22,7 +22,8 @@ import ( "fmt" "sync" - "github.com/dgraph-io/dgraph/conn" + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/x" "github.com/dgryski/go-farm" @@ -72,7 +73,7 @@ func mutate(m *Mutations, left *Mutations) error { } func runMutate(idx int, m *Mutations, wg *sync.WaitGroup, - replies chan *conn.Reply, che chan error) { + replies chan *Payload, che chan error) { defer wg.Done() left := new(Mutations) @@ -81,17 +82,25 @@ func runMutate(idx int, m *Mutations, wg *sync.WaitGroup, return } - var err error pool := pools[idx] - query := new(conn.Query) + var err error + query := new(Payload) query.Data, err = m.Encode() if err != nil { che <- err return } - reply := new(conn.Reply) - if err := pool.Call("Worker.Mutate", query, reply); err != nil { + conn, err := pool.Get() + if err != nil { + che <- err + return + } + defer pool.Put(conn) + c := NewWorkerClient(conn) + + reply, err := c.Mutate(context.Background(), query) + if err != nil { glog.WithField("call", "Worker.Mutate"). WithField("addr", pool.Addr). WithError(err).Error("While calling mutate") @@ -116,7 +125,7 @@ func MutateOverNetwork( } var wg sync.WaitGroup - replies := make(chan *conn.Reply, numInstances) + replies := make(chan *Payload, numInstances) errors := make(chan error, numInstances) for idx, mu := range mutationArray { if mu == nil || len(mu.Set) == 0 { diff --git a/worker/payload.pb.go b/worker/payload.pb.go new file mode 100644 index 00000000..0b60e38b --- /dev/null +++ b/worker/payload.pb.go @@ -0,0 +1,231 @@ +// Code generated by protoc-gen-go. +// source: payload.proto +// DO NOT EDIT! + +/* +Package worker is a generated protocol buffer package. + +It is generated from these files: + payload.proto + +It has these top-level messages: + Payload +*/ +package worker + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Payload struct { + Data []byte `protobuf:"bytes,1,opt,name=Data,json=data,proto3" json:"Data,omitempty"` +} + +func (m *Payload) Reset() { *m = Payload{} } +func (m *Payload) String() string { return proto.CompactTextString(m) } +func (*Payload) ProtoMessage() {} +func (*Payload) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*Payload)(nil), "worker.Payload") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion2 + +// Client API for Worker service + +type WorkerClient interface { + Hello(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + GetOrAssign(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + Mutate(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + ServeTask(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) +} + +type workerClient struct { + cc *grpc.ClientConn +} + +func NewWorkerClient(cc *grpc.ClientConn) WorkerClient { + return &workerClient{cc} +} + +func (c *workerClient) Hello(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/Hello", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) GetOrAssign(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/GetOrAssign", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) Mutate(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/Mutate", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) ServeTask(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/ServeTask", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Worker service + +type WorkerServer interface { + Hello(context.Context, *Payload) (*Payload, error) + GetOrAssign(context.Context, *Payload) (*Payload, error) + Mutate(context.Context, *Payload) (*Payload, error) + ServeTask(context.Context, *Payload) (*Payload, error) +} + +func RegisterWorkerServer(s *grpc.Server, srv WorkerServer) { + s.RegisterService(&_Worker_serviceDesc, srv) +} + +func _Worker_Hello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).Hello(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/Hello", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).Hello(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_GetOrAssign_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).GetOrAssign(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/GetOrAssign", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).GetOrAssign(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_Mutate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).Mutate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/Mutate", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).Mutate(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_ServeTask_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).ServeTask(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/ServeTask", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).ServeTask(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +var _Worker_serviceDesc = grpc.ServiceDesc{ + ServiceName: "worker.Worker", + HandlerType: (*WorkerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Hello", + Handler: _Worker_Hello_Handler, + }, + { + MethodName: "GetOrAssign", + Handler: _Worker_GetOrAssign_Handler, + }, + { + MethodName: "Mutate", + Handler: _Worker_Mutate_Handler, + }, + { + MethodName: "ServeTask", + Handler: _Worker_ServeTask_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, +} + +var fileDescriptor0 = []byte{ + // 151 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x48, 0xac, 0xcc, + 0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0xcf, 0x2f, 0xca, 0x4e, + 0x2d, 0x52, 0x92, 0xe5, 0x62, 0x0f, 0x80, 0x48, 0x08, 0x09, 0x71, 0xb1, 0xb8, 0x24, 0x96, 0x24, + 0x4a, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0xb1, 0xa4, 0x00, 0xd9, 0x46, 0xc7, 0x19, 0xb9, 0xd8, + 0xc2, 0xc1, 0x2a, 0x85, 0xb4, 0xb9, 0x58, 0x3d, 0x52, 0x73, 0x72, 0xf2, 0x85, 0xf8, 0xf5, 0x20, + 0x7a, 0xf5, 0xa0, 0x1a, 0xa5, 0xd0, 0x05, 0x94, 0x18, 0x84, 0x0c, 0xb9, 0xb8, 0xdd, 0x53, 0x4b, + 0xfc, 0x8b, 0x1c, 0x8b, 0x8b, 0x33, 0xd3, 0xf3, 0x88, 0xd2, 0xa2, 0xc3, 0xc5, 0xe6, 0x5b, 0x5a, + 0x92, 0x58, 0x92, 0x4a, 0x94, 0x6a, 0x7d, 0x2e, 0xce, 0xe0, 0xd4, 0xa2, 0xb2, 0xd4, 0x90, 0xc4, + 0xe2, 0x6c, 0x62, 0x34, 0x24, 0xb1, 0x81, 0xfd, 0x6d, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x53, + 0x6e, 0x7b, 0x2e, 0x08, 0x01, 0x00, 0x00, +} diff --git a/worker/payload.proto b/worker/payload.proto new file mode 100644 index 00000000..7806379f --- /dev/null +++ b/worker/payload.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package worker; + +message Payload { + bytes Data = 1; +} + +service Worker { + rpc Hello (Payload) returns (Payload) {} + rpc GetOrAssign (Payload) returns (Payload) {} + rpc Mutate (Payload) returns (Payload) {} + rpc ServeTask (Payload) returns (Payload) {} +} diff --git a/worker/task.go b/worker/task.go index 358fe322..039e0053 100644 --- a/worker/task.go +++ b/worker/task.go @@ -17,12 +17,12 @@ package worker import ( - "github.com/dgraph-io/dgraph/conn" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/x" "github.com/dgryski/go-farm" "github.com/google/flatbuffers/go" + "golang.org/x/net/context" ) func ProcessTaskOverNetwork(qu []byte) (result []byte, rerr error) { @@ -52,13 +52,21 @@ func ProcessTaskOverNetwork(qu []byte) (result []byte, rerr error) { pool := pools[idx] addr := pool.Addr - query := new(conn.Query) + query := new(Payload) query.Data = qu - reply := new(conn.Reply) - if err := pool.Call("Worker.ServeTask", query, reply); err != nil { + + conn, err := pool.Get() + if err != nil { + return []byte(""), err + } + defer pool.Put(conn) + c := NewWorkerClient(conn) + reply, err := c.ServeTask(context.Background(), query) + if err != nil { glog.WithField("call", "Worker.ServeTask").Error(err) return []byte(""), err } + glog.WithField("reply_len", len(reply.Data)).WithField("addr", addr). WithField("attr", attr).Info("Got reply from server") return reply.Data, nil diff --git a/worker/worker.go b/worker/worker.go index b163cfdd..549fa5c7 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -18,11 +18,12 @@ package worker import ( "flag" - "io" "net" - "net/rpc" - "github.com/dgraph-io/dgraph/conn" + "google.golang.org/grpc" + + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/store" "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/x" @@ -35,7 +36,7 @@ var workerPort = flag.String("workerport", ":12345", var glog = x.Log("worker") var dataStore, uidStore *store.Store -var pools []*conn.Pool +var pools []*Pool var numInstances, instanceIdx uint64 func Init(ps, uStore *store.Store, idx, numInst uint64) { @@ -45,39 +46,6 @@ func Init(ps, uStore *store.Store, idx, numInst uint64) { numInstances = numInst } -func Connect(workerList []string) { - w := new(Worker) - if err := rpc.Register(w); err != nil { - glog.Fatal(err) - } - if err := runServer(*workerPort); err != nil { - glog.Fatal(err) - } - if uint64(len(workerList)) != numInstances { - glog.WithField("len(list)", len(workerList)). - WithField("numInstances", numInstances). - Fatalf("Wrong number of instances in workerList") - } - - for _, addr := range workerList { - if len(addr) == 0 { - continue - } - pool := conn.NewPool(addr, 5) - query := new(conn.Query) - query.Data = []byte("hello") - reply := new(conn.Reply) - if err := pool.Call("Worker.Hello", query, reply); err != nil { - glog.WithField("call", "Worker.Hello").Fatal(err) - } - glog.WithField("reply", string(reply.Data)).WithField("addr", addr). - Info("Got reply from server") - pools = append(pools, pool) - } - - glog.Info("Server started. Clients connected.") -} - func NewQuery(attr string, uids []uint64) []byte { b := flatbuffers.NewBuilder(0) task.QueryStartUidsVector(b, len(uids)) @@ -95,21 +63,20 @@ func NewQuery(attr string, uids []uint64) []byte { return b.Bytes[b.Head():] } -type Worker struct { -} +type worker struct{} -func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { - if string(query.Data) == "hello" { - reply.Data = []byte("Oh hello there!") +func (w *worker) Hello(ctx context.Context, in *Payload) (*Payload, error) { + out := new(Payload) + if string(in.Data) == "hello" { + out.Data = []byte("Oh hello there!") } else { - reply.Data = []byte("Hey stranger!") + out.Data = []byte("Hey stranger!") } - return nil -} -func (w *Worker) GetOrAssign(query *conn.Query, - reply *conn.Reply) (rerr error) { + return out, nil +} +func (w *worker) GetOrAssign(ctx context.Context, query *Payload) (*Payload, error) { uo := flatbuffers.GetUOffsetT(query.Data) xids := new(task.XidList) xids.Init(query.Data, uo) @@ -119,25 +86,31 @@ func (w *Worker) GetOrAssign(query *conn.Query, WithField("GetOrAssign", true). Fatal("We shouldn't be receiving this request.") } + + reply := new(Payload) + var rerr error reply.Data, rerr = getOrAssignUids(xids) - return + return reply, rerr } -func (w *Worker) Mutate(query *conn.Query, reply *conn.Reply) (rerr error) { +func (w *worker) Mutate(ctx context.Context, query *Payload) (*Payload, error) { m := new(Mutations) if err := m.Decode(query.Data); err != nil { - return err + return nil, err } left := new(Mutations) if err := mutate(m, left); err != nil { - return err + return nil, err } + + reply := new(Payload) + var rerr error reply.Data, rerr = left.Encode() - return + return reply, rerr } -func (w *Worker) ServeTask(query *conn.Query, reply *conn.Reply) (rerr error) { +func (w *worker) ServeTask(ctx context.Context, query *Payload) (*Payload, error) { uo := flatbuffers.GetUOffsetT(query.Data) q := new(task.Query) q.Init(query.Data, uo) @@ -145,51 +118,62 @@ func (w *Worker) ServeTask(query *conn.Query, reply *conn.Reply) (rerr error) { glog.WithField("attr", attr).WithField("num_uids", q.UidsLength()). WithField("instanceIdx", instanceIdx).Info("ServeTask") + reply := new(Payload) + var rerr error if (instanceIdx == 0 && attr == "_xid_") || farm.Fingerprint64([]byte(attr))%numInstances == instanceIdx { reply.Data, rerr = processTask(query.Data) + } else { glog.WithField("attribute", attr). WithField("instanceIdx", instanceIdx). Fatalf("Request sent to wrong server") } - return rerr + return reply, rerr } -func serveRequests(irwc io.ReadWriteCloser) { - for { - sc := &conn.ServerCodec{ - Rwc: irwc, - } - glog.Info("Serving request from serveRequests") - if err := rpc.ServeRequest(sc); err != nil { - glog.WithField("method", "serveRequests").Info(err) - break - } - } -} - -func runServer(address string) error { - ln, err := net.Listen("tcp", address) +func runServer(port string) { + ln, err := net.Listen("tcp", port) if err != nil { glog.Fatalf("While running server: %v", err) - return err + return } glog.WithField("address", ln.Addr()).Info("Worker listening") - go func() { - for { - cxn, err := ln.Accept() - if err != nil { - glog.Fatalf("listen(%q): %s\n", address, err) - return - } - glog.WithField("local", cxn.LocalAddr()). - WithField("remote", cxn.RemoteAddr()). - Debug("Worker accepted connection") - go serveRequests(cxn) + s := grpc.NewServer(grpc.CustomCodec(&PayloadCodec{})) + RegisterWorkerServer(s, &worker{}) + s.Serve(ln) +} + +func Connect(workerList []string) { + go runServer(*workerPort) + + for _, addr := range workerList { + if len(addr) == 0 { + continue + } + + pool := NewPool(addr, 5) + query := new(Payload) + query.Data = []byte("hello") + + conn, err := pool.Get() + if err != nil { + glog.WithError(err).Fatal("Unable to connect.") } - }() - return nil + + c := NewWorkerClient(conn) + reply, err := c.Hello(context.Background(), query) + if err != nil { + glog.WithError(err).Fatal("Unable to contact.") + } + _ = pool.Put(conn) + + glog.WithField("reply", string(reply.Data)).WithField("addr", addr). + Info("Got reply from server") + pools = append(pools, pool) + } + + glog.Info("Server started. Clients connected.") } -- GitLab