diff --git a/conn/client.go b/conn/client.go deleted file mode 100644 index 47bee5180843093514d2c40525990f76d6cfa5e1..0000000000000000000000000000000000000000 --- a/conn/client.go +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "errors" - "fmt" - "io" - "log" - "net/rpc" -) - -type ClientCodec struct { - Rwc io.ReadWriteCloser - payloadLen int32 -} - -func (c *ClientCodec) WriteRequest(r *rpc.Request, body interface{}) error { - if body == nil { - return fmt.Errorf("Nil request body from client.") - } - - query := body.(*Query) - if err := writeHeader(c.Rwc, r.Seq, r.ServiceMethod, query.Data); err != nil { - return err - } - n, err := c.Rwc.Write(query.Data) - if n != len(query.Data) { - return errors.New("Unable to write payload.") - } - return err -} - -func (c *ClientCodec) ReadResponseHeader(r *rpc.Response) error { - if len(r.Error) > 0 { - log.Fatal("client got response error: " + r.Error) - } - if err := parseHeader(c.Rwc, &r.Seq, - &r.ServiceMethod, &c.payloadLen); err != nil { - return err - } - return nil -} - -func (c *ClientCodec) ReadResponseBody(body interface{}) error { - buf := make([]byte, c.payloadLen) - _, err := io.ReadFull(c.Rwc, buf) - reply := body.(*Reply) - reply.Data = buf - return err -} - -func (c *ClientCodec) Close() error { - return c.Rwc.Close() -} diff --git a/conn/codec.go b/conn/codec.go deleted file mode 100644 index ec921739fceb35e0f213bcc79a354a054989325c..0000000000000000000000000000000000000000 --- a/conn/codec.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - - "github.com/dgraph-io/dgraph/x" -) - -type Query struct { - Data []byte -} - -type Reply struct { - Data []byte - // TODO(manishrjain): Add an error here. - // Error string -} - -func writeHeader(rwc io.ReadWriteCloser, seq uint64, - method string, data []byte) error { - - var bh bytes.Buffer - var rerr error - - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, seq)) - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(method)))) - x.SetError(&rerr, binary.Write(&bh, binary.LittleEndian, int32(len(data)))) - _, err := bh.Write([]byte(method)) - x.SetError(&rerr, err) - if rerr != nil { - return rerr - } - _, err = rwc.Write(bh.Bytes()) - return err -} - -func parseHeader(rwc io.ReadWriteCloser, seq *uint64, - method *string, plen *int32) error { - - var err error - var sz int32 - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, seq)) - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, &sz)) - x.SetError(&err, binary.Read(rwc, binary.LittleEndian, plen)) - if err != nil { - return err - } - - buf := make([]byte, sz) - n, err := rwc.Read(buf) - if err != nil { - return err - } - if n != int(sz) { - return fmt.Errorf("Expected: %v. Got: %v\n", sz, n) - } - *method = string(buf) - return nil -} diff --git a/conn/codec_test.go b/conn/codec_test.go deleted file mode 100644 index d8ff4ebaf3573e04478538b32f0ef10985136716..0000000000000000000000000000000000000000 --- a/conn/codec_test.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "bytes" - "net/rpc" - "testing" -) - -type buf struct { - data chan byte -} - -func newBuf() *buf { - b := new(buf) - b.data = make(chan byte, 10000) - return b -} - -func (b *buf) Read(p []byte) (n int, err error) { - for i := 0; i < len(p); i++ { - p[i] = <-b.data - } - return len(p), nil -} - -func (b *buf) Write(p []byte) (n int, err error) { - for i := 0; i < len(p); i++ { - b.data <- p[i] - } - return len(p), nil -} - -func (b *buf) Close() error { - close(b.data) - return nil -} - -func TestWriteAndParseHeader(t *testing.T) { - b := newBuf() - data := []byte("oh hey") - if err := writeHeader(b, 11, "testing.T", data); err != nil { - t.Error(err) - t.Fail() - } - var seq uint64 - var method string - var plen int32 - if err := parseHeader(b, &seq, &method, &plen); err != nil { - t.Error(err) - t.Fail() - } - if seq != 11 { - t.Errorf("Sequence number. Expected 11. Got: %v", seq) - } - if method != "testing.T" { - t.Errorf("Method name. Expected: testing.T. Got: %v", method) - } - if plen != int32(len(data)) { - t.Errorf("Payload length. Expected: %v. Got: %v", len(data), plen) - } -} - -func TestClientToServer(t *testing.T) { - b := newBuf() - cc := &ClientCodec{ - Rwc: b, - } - sc := &ServerCodec{ - Rwc: b, - } - - r := &rpc.Request{ - ServiceMethod: "Test.ClientServer", - Seq: 11, - } - - query := new(Query) - query.Data = []byte("iamaquery") - if err := cc.WriteRequest(r, query); err != nil { - t.Error(err) - } - - sr := new(rpc.Request) - if err := sc.ReadRequestHeader(sr); err != nil { - t.Error(err) - } - if sr.Seq != r.Seq { - t.Errorf("RPC Seq. Expected: %v. Got: %v", r.Seq, sr.Seq) - } - if sr.ServiceMethod != r.ServiceMethod { - t.Errorf("ServiceMethod. Expected: %v. Got: %v", - r.ServiceMethod, sr.ServiceMethod) - } - - squery := new(Query) - if err := sc.ReadRequestBody(squery); err != nil { - t.Error(err) - } - if !bytes.Equal(squery.Data, query.Data) { - t.Errorf("Queries don't match. Expected: %v Got: %v", - string(query.Data), string(squery.Data)) - } -} - -func TestServerToClient(t *testing.T) { - b := newBuf() - cc := &ClientCodec{ - Rwc: b, - } - sc := &ServerCodec{ - Rwc: b, - } - - r := &rpc.Response{ - ServiceMethod: "Test.ClientServer", - Seq: 11, - } - - reply := new(Reply) - reply.Data = []byte("iamareply") - if err := sc.WriteResponse(r, reply); err != nil { - t.Error(err) - } - - cr := new(rpc.Response) - if err := cc.ReadResponseHeader(cr); err != nil { - t.Error(err) - } - if cr.Seq != r.Seq { - t.Errorf("RPC Seq. Expected: %v. Got: %v", r.Seq, cr.Seq) - } - if cr.ServiceMethod != r.ServiceMethod { - t.Errorf("ServiceMethod. Expected: %v. Got: %v", - r.ServiceMethod, cr.ServiceMethod) - } - - creply := new(Reply) - if err := cc.ReadResponseBody(creply); err != nil { - t.Error(err) - } - if !bytes.Equal(creply.Data, reply.Data) { - t.Errorf("Replies don't match. Expected: %v Got: %v", - string(reply.Data), string(creply.Data)) - } -} diff --git a/conn/pool.go b/conn/pool.go deleted file mode 100644 index 76a107e53b19518223dbbd61ab4c7e7d18e31cfc..0000000000000000000000000000000000000000 --- a/conn/pool.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "net" - "net/rpc" - "strings" - "time" - - "github.com/dgraph-io/dgraph/x" -) - -var glog = x.Log("conn") - -type Pool struct { - clients chan *rpc.Client - Addr string -} - -func NewPool(addr string, maxCap int) *Pool { - p := new(Pool) - p.Addr = addr - p.clients = make(chan *rpc.Client, maxCap) - client, err := p.dialNew() - if err != nil { - glog.Fatal(err) - return nil - } - p.clients <- client - return p -} - -func (p *Pool) dialNew() (*rpc.Client, error) { - d := &net.Dialer{ - Timeout: 3 * time.Minute, - } - var nconn net.Conn - var err error - // This loop will retry for 10 minutes before giving up. - for i := 0; i < 60; i++ { - nconn, err = d.Dial("tcp", p.Addr) - if err == nil { - break - } - if !strings.Contains(err.Error(), "refused") { - break - } - - glog.WithField("error", err).WithField("addr", p.Addr). - Info("Retrying connection...") - time.Sleep(10 * time.Second) - } - if err != nil { - return nil, err - } - cc := &ClientCodec{ - Rwc: nconn, - } - return rpc.NewClientWithCodec(cc), nil -} - -func (p *Pool) Call(serviceMethod string, args interface{}, - reply interface{}) error { - - client, err := p.get() - if err != nil { - return err - } - if err = client.Call(serviceMethod, args, reply); err != nil { - return err - } - - select { - case p.clients <- client: - return nil - default: - return client.Close() - } -} - -func (p *Pool) get() (*rpc.Client, error) { - select { - case client := <-p.clients: - return client, nil - default: - return p.dialNew() - } -} - -func (p *Pool) Close() error { - // We're not doing a clean exit here. A clean exit here would require - // synchronization, which seems unnecessary for now. But, we should - // add one if required later. - return nil -} diff --git a/conn/server.go b/conn/server.go deleted file mode 100644 index 406ea9bc8a96780d56fc9ed37d60aa04ab5278c5..0000000000000000000000000000000000000000 --- a/conn/server.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2016 DGraph Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package conn - -import ( - "io" - "log" - "net/rpc" -) - -type ServerCodec struct { - Rwc io.ReadWriteCloser - payloadLen int32 -} - -func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error { - return parseHeader(c.Rwc, &r.Seq, &r.ServiceMethod, &c.payloadLen) -} - -func (c *ServerCodec) ReadRequestBody(data interface{}) error { - b := make([]byte, c.payloadLen) - _, err := io.ReadFull(c.Rwc, b) - if err != nil { - return err - } - - if data == nil { - // If data is nil, discard this request. - return nil - } - query := data.(*Query) - query.Data = b - return nil -} - -func (c *ServerCodec) WriteResponse(resp *rpc.Response, - data interface{}) error { - - if len(resp.Error) > 0 { - log.Fatal("Response has error: " + resp.Error) - } - if data == nil { - log.Fatal("Worker write response data is nil") - } - reply, ok := data.(*Reply) - if !ok { - log.Fatal("Unable to convert to reply") - } - - if err := writeHeader(c.Rwc, resp.Seq, - resp.ServiceMethod, reply.Data); err != nil { - return err - } - - _, err := c.Rwc.Write(reply.Data) - return err -} - -func (c *ServerCodec) Close() error { - return c.Rwc.Close() -} diff --git a/vendor/github.com/golang/protobuf/proto/clone.go b/vendor/github.com/golang/protobuf/proto/clone.go index e98ddec9815b6da6a1483127656d5018656afccd..e392575b353afa4f22f513d3f22ff64a8f5fdbf1 100644 --- a/vendor/github.com/golang/protobuf/proto/clone.go +++ b/vendor/github.com/golang/protobuf/proto/clone.go @@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) { mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) } - if emIn, ok := in.Addr().Interface().(extendableProto); ok { - emOut := out.Addr().Interface().(extendableProto) - mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) + if emIn, ok := extendable(in.Addr().Interface()); ok { + emOut, _ := extendable(out.Addr().Interface()) + mIn, muIn := emIn.extensionsRead() + if mIn != nil { + mOut := emOut.extensionsWrite() + muIn.Lock() + mergeExtension(mOut, mIn) + muIn.Unlock() + } } uf := in.FieldByName("XXX_unrecognized") diff --git a/vendor/github.com/golang/protobuf/proto/decode.go b/vendor/github.com/golang/protobuf/proto/decode.go index f94b9f416ecd54f6b72d1f4a6faacd43a3fb6e96..07288a250b47add2f1fa939d9951a307918360f7 100644 --- a/vendor/github.com/golang/protobuf/proto/decode.go +++ b/vendor/github.com/golang/protobuf/proto/decode.go @@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group if !ok { // Maybe it's an extension? if prop.extendable { - if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { + if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) { if err = o.skip(st, tag, wire); err == nil { - ext := e.ExtensionMap()[int32(tag)] // may be missing + extmap := e.extensionsWrite() + ext := extmap[int32(tag)] // may be missing ext.enc = append(ext.enc, o.buf[oi:o.index]...) - e.ExtensionMap()[int32(tag)] = ext + extmap[int32(tag)] = ext } continue } diff --git a/vendor/github.com/golang/protobuf/proto/encode.go b/vendor/github.com/golang/protobuf/proto/encode.go index eb7e0474ef6a864804c235931f57350408e9e2f0..8c1b8fd1f6831579d821a8bce8d657d597f55fb0 100644 --- a/vendor/github.com/golang/protobuf/proto/encode.go +++ b/vendor/github.com/golang/protobuf/proto/encode.go @@ -70,6 +70,10 @@ var ( // ErrNil is the error returned if Marshal is called with nil. ErrNil = errors.New("proto: Marshal called with nil") + + // ErrTooLarge is the error returned if Marshal is called with a + // message that encodes to >2GB. + ErrTooLarge = errors.New("proto: message encodes to over 2 GB") ) // The fundamental encoders that put bytes on the wire. @@ -78,6 +82,10 @@ var ( const maxVarintBytes = 10 // maximum length of a varint +// maxMarshalSize is the largest allowed size of an encoded protobuf, +// since C++ and Java use signed int32s for the size. +const maxMarshalSize = 1<<31 - 1 + // EncodeVarint returns the varint encoding of x. // This is the format for the // int32, int64, uint32, uint64, bool, and enum @@ -277,6 +285,9 @@ func (p *Buffer) Marshal(pb Message) error { stats.Encode++ } + if len(p.buf) > maxMarshalSize { + return ErrTooLarge + } return err } @@ -1062,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) { // Encode an extension map. func (o *Buffer) enc_map(p *Properties, base structPointer) error { - v := *structPointer_ExtMap(base, p.field) - if err := encodeExtensionMap(v); err != nil { + exts := structPointer_ExtMap(base, p.field) + if err := encodeExtensionsMap(*exts); err != nil { return err } + + return o.enc_map_body(*exts) +} + +func (o *Buffer) enc_exts(p *Properties, base structPointer) error { + exts := structPointer_Extensions(base, p.field) + if err := encodeExtensions(exts); err != nil { + return err + } + v, _ := exts.extensionsRead() + + return o.enc_map_body(v) +} + +func (o *Buffer) enc_map_body(v map[int32]Extension) error { // Fast-path for common cases: zero or one extensions. if len(v) <= 1 { for _, e := range v { @@ -1088,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error { } func size_map(p *Properties, base structPointer) int { - v := *structPointer_ExtMap(base, p.field) - return sizeExtensionMap(v) + v := structPointer_ExtMap(base, p.field) + return extensionsMapSize(*v) +} + +func size_exts(p *Properties, base structPointer) int { + v := structPointer_Extensions(base, p.field) + return extensionsSize(v) } // Encode a map field. @@ -1118,7 +1149,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { return err } - if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil { + if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil { return err } return nil @@ -1128,11 +1159,6 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { for _, key := range v.MapKeys() { val := v.MapIndex(key) - // The only illegal map entry values are nil message pointers. - if val.Kind() == reflect.Ptr && val.IsNil() { - return errors.New("proto: map has nil element") - } - keycopy.Set(key) valcopy.Set(val) @@ -1220,6 +1246,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { return err } } + if len(o.buf) > maxMarshalSize { + return ErrTooLarge + } } } @@ -1236,6 +1265,9 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { // Add unrecognized fields at the end. if prop.unrecField.IsValid() { v := *structPointer_Bytes(base, prop.unrecField) + if len(o.buf)+len(v) > maxMarshalSize { + return ErrTooLarge + } if len(v) > 0 { o.buf = append(o.buf, v...) } diff --git a/vendor/github.com/golang/protobuf/proto/equal.go b/vendor/github.com/golang/protobuf/proto/equal.go index f5db1def3c2483e5fb7879bfe318dea69964147c..8b16f951c712703a03a14591510818da15557b01 100644 --- a/vendor/github.com/golang/protobuf/proto/equal.go +++ b/vendor/github.com/golang/protobuf/proto/equal.go @@ -121,9 +121,16 @@ func equalStruct(v1, v2 reflect.Value) bool { } } + if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_InternalExtensions") + if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) { + return false + } + } + if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { em2 := v2.FieldByName("XXX_extensions") - if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { return false } } @@ -184,6 +191,13 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool { } return true case reflect.Ptr: + // Maps may have nil values in them, so check for nil. + if v1.IsNil() && v2.IsNil() { + return true + } + if v1.IsNil() != v2.IsNil() { + return false + } return equalAny(v1.Elem(), v2.Elem(), prop) case reflect.Slice: if v1.Type().Elem().Kind() == reflect.Uint8 { @@ -223,8 +237,14 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool { } // base is the struct type that the extensions are based on. -// em1 and em2 are extension maps. -func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { +// x1 and x2 are InternalExtensions. +func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool { + em1, _ := x1.extensionsRead() + em2, _ := x2.extensionsRead() + return equalExtMap(base, em1, em2) +} + +func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool { if len(em1) != len(em2) { return false } diff --git a/vendor/github.com/golang/protobuf/proto/extensions.go b/vendor/github.com/golang/protobuf/proto/extensions.go index 054f4f1df78872e3e9a5ec2ad41a937e463631f5..9f484f53a3ae8e9262820ba8145b14bb8b37dfd0 100644 --- a/vendor/github.com/golang/protobuf/proto/extensions.go +++ b/vendor/github.com/golang/protobuf/proto/extensions.go @@ -52,14 +52,99 @@ type ExtensionRange struct { Start, End int32 // both inclusive } -// extendableProto is an interface implemented by any protocol buffer that may be extended. +// extendableProto is an interface implemented by any protocol buffer generated by the current +// proto compiler that may be extended. type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + extensionsWrite() map[int32]Extension + extensionsRead() (map[int32]Extension, sync.Locker) +} + +// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous +// version of the proto compiler that may be extended. +type extendableProtoV1 interface { Message ExtensionRangeArray() []ExtensionRange ExtensionMap() map[int32]Extension } +// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. +type extensionAdapter struct { + extendableProtoV1 +} + +func (e extensionAdapter) extensionsWrite() map[int32]Extension { + return e.ExtensionMap() +} + +func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { + return e.ExtensionMap(), notLocker{} +} + +// notLocker is a sync.Locker whose Lock and Unlock methods are nops. +type notLocker struct{} + +func (n notLocker) Lock() {} +func (n notLocker) Unlock() {} + +// extendable returns the extendableProto interface for the given generated proto message. +// If the proto message has the old extension format, it returns a wrapper that implements +// the extendableProto interface. +func extendable(p interface{}) (extendableProto, bool) { + if ep, ok := p.(extendableProto); ok { + return ep, ok + } + if ep, ok := p.(extendableProtoV1); ok { + return extensionAdapter{ep}, ok + } + return nil, false +} + +// XXX_InternalExtensions is an internal representation of proto extensions. +// +// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, +// thus gaining the unexported 'extensions' method, which can be called only from the proto package. +// +// The methods of XXX_InternalExtensions are not concurrency safe in general, +// but calls to logically read-only methods such as has and get may be executed concurrently. +type XXX_InternalExtensions struct { + // The struct must be indirect so that if a user inadvertently copies a + // generated message and its embedded XXX_InternalExtensions, they + // avoid the mayhem of a copied mutex. + // + // The mutex serializes all logically read-only operations to p.extensionMap. + // It is up to the client to ensure that write operations to p.extensionMap are + // mutually exclusive with other accesses. + p *struct { + mu sync.Mutex + extensionMap map[int32]Extension + } +} + +// extensionsWrite returns the extension map, creating it on first use. +func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { + if e.p == nil { + e.p = new(struct { + mu sync.Mutex + extensionMap map[int32]Extension + }) + e.p.extensionMap = make(map[int32]Extension) + } + return e.p.extensionMap +} + +// extensionsRead returns the extensions map for read-only use. It may be nil. +// The caller must hold the returned mutex's lock when accessing Elements within the map. +func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { + if e.p == nil { + return nil, nil + } + return e.p.extensionMap, &e.p.mu +} + var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() +var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() // ExtensionDesc represents an extension specification. // Used in generated code from the protocol compiler. @@ -92,8 +177,13 @@ type Extension struct { } // SetRawExtension is for testing only. -func SetRawExtension(base extendableProto, id int32, b []byte) { - base.ExtensionMap()[id] = Extension{enc: b} +func SetRawExtension(base Message, id int32, b []byte) { + epb, ok := extendable(base) + if !ok { + return + } + extmap := epb.extensionsWrite() + extmap[id] = Extension{enc: b} } // isExtensionField returns true iff the given field number is in an extension range. @@ -108,8 +198,12 @@ func isExtensionField(pb extendableProto, field int32) bool { // checkExtensionTypes checks that the given extension is valid for pb. func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + var pbi interface{} = pb // Check the extended type. - if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { + if ea, ok := pbi.(extensionAdapter); ok { + pbi = ea.extendableProtoV1 + } + if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) } // Check the range. @@ -155,8 +249,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties { return prop } -// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. -func encodeExtensionMap(m map[int32]Extension) error { +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensions(e *XXX_InternalExtensions) error { + m, mu := e.extensionsRead() + if m == nil { + return nil // fast path + } + mu.Lock() + defer mu.Unlock() + return encodeExtensionsMap(m) +} + +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensionsMap(m map[int32]Extension) error { for k, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -184,7 +289,17 @@ func encodeExtensionMap(m map[int32]Extension) error { return nil } -func sizeExtensionMap(m map[int32]Extension) (n int) { +func extensionsSize(e *XXX_InternalExtensions) (n int) { + m, mu := e.extensionsRead() + if m == nil { + return 0 + } + mu.Lock() + defer mu.Unlock() + return extensionsMapSize(m) +} + +func extensionsMapSize(m map[int32]Extension) (n int) { for _, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -209,26 +324,51 @@ func sizeExtensionMap(m map[int32]Extension) (n int) { } // HasExtension returns whether the given extension is present in pb. -func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { +func HasExtension(pb Message, extension *ExtensionDesc) bool { // TODO: Check types, field numbers, etc.? - _, ok := pb.ExtensionMap()[extension.Field] + epb, ok := extendable(pb) + if !ok { + return false + } + extmap, mu := epb.extensionsRead() + if extmap == nil { + return false + } + mu.Lock() + _, ok = extmap[extension.Field] + mu.Unlock() return ok } // ClearExtension removes the given extension from pb. -func ClearExtension(pb extendableProto, extension *ExtensionDesc) { +func ClearExtension(pb Message, extension *ExtensionDesc) { + epb, ok := extendable(pb) + if !ok { + return + } // TODO: Check types, field numbers, etc.? - delete(pb.ExtensionMap(), extension.Field) + extmap := epb.extensionsWrite() + delete(extmap, extension.Field) } // GetExtension parses and returns the given extension of pb. // If the extension is not present and has no default value it returns ErrMissingExtension. -func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { - if err := checkExtensionTypes(pb, extension); err != nil { +func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { + epb, ok := extendable(pb) + if !ok { + return nil, errors.New("proto: not an extendable proto") + } + + if err := checkExtensionTypes(epb, extension); err != nil { return nil, err } - emap := pb.ExtensionMap() + emap, mu := epb.extensionsRead() + if emap == nil { + return defaultExtensionValue(extension) + } + mu.Lock() + defer mu.Unlock() e, ok := emap[extension.Field] if !ok { // defaultExtensionValue returns the default value or @@ -332,10 +472,9 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { // GetExtensions returns a slice of the extensions present in pb that are also listed in es. // The returned slice has the same length as es; missing extensions will appear as nil elements. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { - err = errors.New("proto: not an extendable proto") - return + return nil, errors.New("proto: not an extendable proto") } extensions = make([]interface{}, len(es)) for i, e := range es { @@ -351,8 +490,12 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e } // SetExtension sets the specified extension of pb to the specified value. -func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { - if err := checkExtensionTypes(pb, extension); err != nil { +func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { + epb, ok := extendable(pb) + if !ok { + return errors.New("proto: not an extendable proto") + } + if err := checkExtensionTypes(epb, extension); err != nil { return err } typ := reflect.TypeOf(extension.ExtensionType) @@ -368,10 +511,23 @@ func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{ return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) } - pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + extmap := epb.extensionsWrite() + extmap[extension.Field] = Extension{desc: extension, value: value} return nil } +// ClearAllExtensions clears all extensions from pb. +func ClearAllExtensions(pb Message) { + epb, ok := extendable(pb) + if !ok { + return + } + m := epb.extensionsWrite() + for k := range m { + delete(m, k) + } +} + // A global registry of extensions. // The generated code will register the generated descriptors by calling RegisterExtension. diff --git a/vendor/github.com/golang/protobuf/proto/lib.go b/vendor/github.com/golang/protobuf/proto/lib.go index 0de8f8dffd0880ceea9abd8176eaa2adb1e11109..170b8e87d2e460bc3c8427af849bf4d5978e79b7 100644 --- a/vendor/github.com/golang/protobuf/proto/lib.go +++ b/vendor/github.com/golang/protobuf/proto/lib.go @@ -889,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool { return false } +// ProtoPackageIsVersion2 is referenced from generated protocol buffer files +// to assert that that code is compatible with this version of the proto package. +const ProtoPackageIsVersion2 = true + // ProtoPackageIsVersion1 is referenced from generated protocol buffer files // to assert that that code is compatible with this version of the proto package. const ProtoPackageIsVersion1 = true diff --git a/vendor/github.com/golang/protobuf/proto/message_set.go b/vendor/github.com/golang/protobuf/proto/message_set.go index e25e01e637483854ff51d528627d83065d525827..fd982decd66e4846031a72a785470be20afe99a5 100644 --- a/vendor/github.com/golang/protobuf/proto/message_set.go +++ b/vendor/github.com/golang/protobuf/proto/message_set.go @@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte { // MarshalMessageSet encodes the extension map represented by m in the message set wire format. // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { - if err := encodeExtensionMap(m); err != nil { - return nil, err +func MarshalMessageSet(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + if err := encodeExtensions(exts); err != nil { + return nil, err + } + m, _ = exts.extensionsRead() + case map[int32]Extension: + if err := encodeExtensionsMap(exts); err != nil { + return nil, err + } + m = exts + default: + return nil, errors.New("proto: not an extension map") } // Sort extension IDs to provide a deterministic encoding. @@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSet(buf []byte, exts interface{}) error { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m = exts.extensionsWrite() + case map[int32]Extension: + m = exts + default: + return errors.New("proto: not an extension map") + } + ms := new(messageSet) if err := Unmarshal(buf, ms); err != nil { return err @@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { // MarshalMessageSetJSON encodes the extension map represented by m in JSON format. // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { +func MarshalMessageSetJSON(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m, _ = exts.extensionsRead() + case map[int32]Extension: + m = exts + default: + return nil, errors.New("proto: not an extension map") + } var b bytes.Buffer b.WriteByte('{') @@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error { // Common-case fast path. if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { return nil diff --git a/vendor/github.com/golang/protobuf/proto/pointer_reflect.go b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go index 989914177d0ccce03dbceee43bca5f478594a171..fb512e2e16dce05683722f810c279367bdc68fe9 100644 --- a/vendor/github.com/golang/protobuf/proto/pointer_reflect.go +++ b/vendor/github.com/golang/protobuf/proto/pointer_reflect.go @@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { return structPointer_ifield(p, f).(*[]string) } +// Extensions returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return structPointer_ifield(p, f).(*XXX_InternalExtensions) +} + // ExtMap returns the address of an extension map field in the struct. func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return structPointer_ifield(p, f).(*map[int32]Extension) diff --git a/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go index ceece772a2d0e5724e93ffef35be94dd1eb573eb..6b5567d47cd396b25370f8c06bad3b851776658f 100644 --- a/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go +++ b/vendor/github.com/golang/protobuf/proto/pointer_unsafe.go @@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { } // ExtMap returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) } diff --git a/vendor/github.com/golang/protobuf/proto/properties.go b/vendor/github.com/golang/protobuf/proto/properties.go index 880eb22d8fdc214eee2b370cd944d413ffff28e8..dd29683c6e36ba78f8756918687cd7eeb84e90aa 100644 --- a/vendor/github.com/golang/protobuf/proto/properties.go +++ b/vendor/github.com/golang/protobuf/proto/properties.go @@ -473,17 +473,13 @@ func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lock p.dec = (*Buffer).dec_slice_int64 p.packedDec = (*Buffer).dec_slice_packed_int64 case reflect.Uint8: - p.enc = (*Buffer).enc_slice_byte p.dec = (*Buffer).dec_slice_byte - p.size = size_slice_byte - // This is a []byte, which is either a bytes field, - // or the value of a map field. In the latter case, - // we always encode an empty []byte, so we should not - // use the proto3 enc/size funcs. - // f == nil iff this is the key/value of a map field. - if p.proto3 && f != nil { + if p.proto3 { p.enc = (*Buffer).enc_proto3_slice_byte p.size = size_proto3_slice_byte + } else { + p.enc = (*Buffer).enc_slice_byte + p.size = size_slice_byte } case reflect.Float32, reflect.Float64: switch t2.Bits() { @@ -682,7 +678,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { propertiesMap[t] = prop // build properties - prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) || + reflect.PtrTo(t).Implements(extendableProtoV1Type) prop.unrecField = invalidField prop.Prop = make([]*Properties, t.NumField()) prop.order = make([]int, t.NumField()) @@ -693,12 +690,15 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { name := f.Name p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) - if f.Name == "XXX_extensions" { // special case + if f.Name == "XXX_InternalExtensions" { // special case + p.enc = (*Buffer).enc_exts + p.dec = nil // not needed + p.size = size_exts + } else if f.Name == "XXX_extensions" { // special case p.enc = (*Buffer).enc_map p.dec = nil // not needed p.size = size_map - } - if f.Name == "XXX_unrecognized" { // special case + } else if f.Name == "XXX_unrecognized" { // special case prop.unrecField = toField(&f) } oneof := f.Tag.Get("protobuf_oneof") // special case diff --git a/vendor/github.com/golang/protobuf/proto/text.go b/vendor/github.com/golang/protobuf/proto/text.go index 37c953570d341e8cc203ffe91822512f54fa06ea..8214ce326b54d9fd35f663b41c8895f876ea61ac 100644 --- a/vendor/github.com/golang/protobuf/proto/text.go +++ b/vendor/github.com/golang/protobuf/proto/text.go @@ -455,7 +455,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { // Extensions (the XXX_extensions field). pv := sv.Addr() - if pv.Type().Implements(extendableProtoType) { + if _, ok := extendable(pv.Interface()); ok { if err := tm.writeExtensions(w, pv); err != nil { return err } @@ -513,7 +513,7 @@ func (tm *TextMarshaler) writeAny(w *textWriter, v reflect.Value, props *Propert switch v.Kind() { case reflect.Slice: // Should only be a []byte; repeated fields are handled in writeStruct. - if err := writeString(w, string(v.Interface().([]byte))); err != nil { + if err := writeString(w, string(v.Bytes())); err != nil { return err } case reflect.String: @@ -689,17 +689,22 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // pv is assumed to be a pointer to a protocol message struct that is extendable. func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error { emap := extensionMaps[pv.Type().Elem()] - ep := pv.Interface().(extendableProto) + ep, _ := extendable(pv.Interface()) // Order the extensions by ID. // This isn't strictly necessary, but it will give us // canonical output, which will also make testing easier. - m := ep.ExtensionMap() + m, mu := ep.extensionsRead() + if m == nil { + return nil + } + mu.Lock() ids := make([]int32, 0, len(m)) for id := range m { ids = append(ids, id) } sort.Sort(int32Slice(ids)) + mu.Unlock() for _, extNum := range ids { ext := m[extNum] diff --git a/vendor/github.com/golang/protobuf/proto/text_parser.go b/vendor/github.com/golang/protobuf/proto/text_parser.go index b5fba5b2ae0971aa660f4ddc4b213eaeeea889d1..0b8c59f746eb7056c50666229ae6390d5019aa49 100644 --- a/vendor/github.com/golang/protobuf/proto/text_parser.go +++ b/vendor/github.com/golang/protobuf/proto/text_parser.go @@ -550,7 +550,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { } reqFieldErr = err } - ep := sv.Addr().Interface().(extendableProto) + ep := sv.Addr().Interface().(Message) if !rep { SetExtension(ep, desc, ext.Interface()) } else { @@ -602,8 +602,9 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { // The map entry should be this sequence of tokens: // < key : KEY value : VALUE > - // Technically the "key" and "value" could come in any order, - // but in practice they won't. + // However, implementations may omit key or value, and technically + // we should support them in any order. See b/28924776 for a time + // this went wrong. tok := p.next() var terminator string @@ -615,32 +616,39 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { default: return p.errorf("expected '{' or '<', found %q", tok.value) } - if err := p.consumeToken("key"); err != nil { - return err - } - if err := p.consumeToken(":"); err != nil { - return err - } - if err := p.readAny(key, props.mkeyprop); err != nil { - return err - } - if err := p.consumeOptionalSeparator(); err != nil { - return err - } - if err := p.consumeToken("value"); err != nil { - return err - } - if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { - return err - } - if err := p.readAny(val, props.mvalprop); err != nil { - return err - } - if err := p.consumeOptionalSeparator(); err != nil { - return err - } - if err := p.consumeToken(terminator); err != nil { - return err + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + switch tok.value { + case "key": + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(key, props.mkeyprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + case "value": + if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { + return err + } + if err := p.readAny(val, props.mvalprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + default: + p.back() + return p.errorf(`expected "key", "value", or %q, found %q`, terminator, tok.value) + } } dst.SetMapIndex(key, val) diff --git a/vendor/vendor.json b/vendor/vendor.json index b7373bef01987c4d3798e3743422869bd7e5db8b..2ec827ead528657fbf4290c812809140a0ffb376 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -15,10 +15,10 @@ "revisionTime": "2015-09-09T17:09:13Z" }, { - "checksumSHA1": "FczzogSoZcKU3h21tCUyHzMsnBY=", + "checksumSHA1": "Yz1gT/uux3UsjbGDG/OP7gZREY0=", "path": "github.com/golang/protobuf/proto", - "revision": "7cc19b78d562895b13596ddce7aafb59dd789318", - "revisionTime": "2016-04-25T21:53:00Z" + "revision": "3b06fc7a4cad73efce5fe6217ab6c33e7231ab4a", + "revisionTime": "2016-06-01T15:17:05Z" }, { "checksumSHA1": "bAXE13JoHB8GZ5svBJpzCWiuZBQ=", diff --git a/worker/README.md b/worker/README.md new file mode 100644 index 0000000000000000000000000000000000000000..01550d836c93a50dbf934b3e4565fd7425c9b25b --- /dev/null +++ b/worker/README.md @@ -0,0 +1,5 @@ +To update the protocol buffer definitions, run this from one directory above: + +``` +protoc -I worker worker/payload.proto --go_out=plugins=grpc:worker +``` diff --git a/worker/assign.go b/worker/assign.go index 99e24525af0eefe61153b539c9f991f54f129811..45ed5b5fb93afa06d8072f3c179cd49288e64970 100644 --- a/worker/assign.go +++ b/worker/assign.go @@ -20,7 +20,8 @@ import ( "fmt" "sync" - "github.com/dgraph-io/dgraph/conn" + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/uid" "github.com/google/flatbuffers/go" @@ -93,13 +94,13 @@ func getOrAssignUids( } func GetOrAssignUidsOverNetwork(xidToUid *map[string]uint64) (rerr error) { - query := new(conn.Query) + query := new(Payload) query.Data = createXidListBuffer(*xidToUid) uo := flatbuffers.GetUOffsetT(query.Data) xidList := new(task.XidList) xidList.Init(query.Data, uo) - reply := new(conn.Reply) + reply := new(Payload) if instanceIdx == 0 { uo := flatbuffers.GetUOffsetT(query.Data) xidList := new(task.XidList) @@ -109,13 +110,22 @@ func GetOrAssignUidsOverNetwork(xidToUid *map[string]uint64) (rerr error) { if rerr != nil { return rerr } + } else { pool := pools[0] - if err := pool.Call("Worker.GetOrAssign", query, reply); err != nil { - glog.WithField("method", "GetOrAssign").WithError(err). - Error("While getting uids") + conn, err := pool.Get() + if err != nil { + glog.WithError(err).Error("Unable to retrieve connection.") return err } + c := NewWorkerClient(conn) + + reply, rerr = c.GetOrAssign(context.Background(), query) + if rerr != nil { + glog.WithField("method", "GetOrAssign").WithError(rerr). + Error("While getting uids") + return rerr + } } uidList := new(task.UidList) diff --git a/worker/conn.go b/worker/conn.go new file mode 100644 index 0000000000000000000000000000000000000000..482fecb9c52b469644b941e71880dfb5fe23719d --- /dev/null +++ b/worker/conn.go @@ -0,0 +1,71 @@ +package worker + +import ( + "log" + + "google.golang.org/grpc" +) + +type PayloadCodec struct{} + +func (cb *PayloadCodec) Marshal(v interface{}) ([]byte, error) { + p, ok := v.(*Payload) + if !ok { + log.Fatalf("Invalid type of struct: %+v", v) + } + return p.Data, nil +} + +func (cb *PayloadCodec) Unmarshal(data []byte, v interface{}) error { + p, ok := v.(*Payload) + if !ok { + log.Fatalf("Invalid type of struct: %+v", v) + } + p.Data = data + return nil +} + +func (cb *PayloadCodec) String() string { + return "worker.PayloadCodec" +} + +type Pool struct { + conns chan *grpc.ClientConn + Addr string +} + +func NewPool(addr string, maxCap int) *Pool { + p := new(Pool) + p.Addr = addr + p.conns = make(chan *grpc.ClientConn, maxCap) + conn, err := p.dialNew() + if err != nil { + glog.Fatal(err) + return nil + } + p.conns <- conn + return p +} + +func (p *Pool) dialNew() (*grpc.ClientConn, error) { + return grpc.Dial(p.Addr, grpc.WithInsecure(), grpc.WithInsecure(), + grpc.WithCodec(&PayloadCodec{})) +} + +func (p *Pool) Get() (*grpc.ClientConn, error) { + select { + case conn := <-p.conns: + return conn, nil + default: + return p.dialNew() + } +} + +func (p *Pool) Put(conn *grpc.ClientConn) error { + select { + case p.conns <- conn: + return nil + default: + return conn.Close() + } +} diff --git a/worker/mutation.go b/worker/mutation.go index 36a6666de0cf72828cf0aae23c1b3b9f6951d36e..78c2a6a5cad9a2e41c3ec06faa92dc5294cb5099 100644 --- a/worker/mutation.go +++ b/worker/mutation.go @@ -22,7 +22,8 @@ import ( "fmt" "sync" - "github.com/dgraph-io/dgraph/conn" + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/x" "github.com/dgryski/go-farm" @@ -72,7 +73,7 @@ func mutate(m *Mutations, left *Mutations) error { } func runMutate(idx int, m *Mutations, wg *sync.WaitGroup, - replies chan *conn.Reply, che chan error) { + replies chan *Payload, che chan error) { defer wg.Done() left := new(Mutations) @@ -81,17 +82,25 @@ func runMutate(idx int, m *Mutations, wg *sync.WaitGroup, return } - var err error pool := pools[idx] - query := new(conn.Query) + var err error + query := new(Payload) query.Data, err = m.Encode() if err != nil { che <- err return } - reply := new(conn.Reply) - if err := pool.Call("Worker.Mutate", query, reply); err != nil { + conn, err := pool.Get() + if err != nil { + che <- err + return + } + defer pool.Put(conn) + c := NewWorkerClient(conn) + + reply, err := c.Mutate(context.Background(), query) + if err != nil { glog.WithField("call", "Worker.Mutate"). WithField("addr", pool.Addr). WithError(err).Error("While calling mutate") @@ -116,7 +125,7 @@ func MutateOverNetwork( } var wg sync.WaitGroup - replies := make(chan *conn.Reply, numInstances) + replies := make(chan *Payload, numInstances) errors := make(chan error, numInstances) for idx, mu := range mutationArray { if mu == nil || len(mu.Set) == 0 { diff --git a/worker/payload.pb.go b/worker/payload.pb.go new file mode 100644 index 0000000000000000000000000000000000000000..0b60e38be00cddce3718716985912ac942e23d26 --- /dev/null +++ b/worker/payload.pb.go @@ -0,0 +1,231 @@ +// Code generated by protoc-gen-go. +// source: payload.proto +// DO NOT EDIT! + +/* +Package worker is a generated protocol buffer package. + +It is generated from these files: + payload.proto + +It has these top-level messages: + Payload +*/ +package worker + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Payload struct { + Data []byte `protobuf:"bytes,1,opt,name=Data,json=data,proto3" json:"Data,omitempty"` +} + +func (m *Payload) Reset() { *m = Payload{} } +func (m *Payload) String() string { return proto.CompactTextString(m) } +func (*Payload) ProtoMessage() {} +func (*Payload) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*Payload)(nil), "worker.Payload") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion2 + +// Client API for Worker service + +type WorkerClient interface { + Hello(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + GetOrAssign(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + Mutate(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) + ServeTask(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) +} + +type workerClient struct { + cc *grpc.ClientConn +} + +func NewWorkerClient(cc *grpc.ClientConn) WorkerClient { + return &workerClient{cc} +} + +func (c *workerClient) Hello(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/Hello", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) GetOrAssign(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/GetOrAssign", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) Mutate(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/Mutate", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerClient) ServeTask(ctx context.Context, in *Payload, opts ...grpc.CallOption) (*Payload, error) { + out := new(Payload) + err := grpc.Invoke(ctx, "/worker.Worker/ServeTask", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Worker service + +type WorkerServer interface { + Hello(context.Context, *Payload) (*Payload, error) + GetOrAssign(context.Context, *Payload) (*Payload, error) + Mutate(context.Context, *Payload) (*Payload, error) + ServeTask(context.Context, *Payload) (*Payload, error) +} + +func RegisterWorkerServer(s *grpc.Server, srv WorkerServer) { + s.RegisterService(&_Worker_serviceDesc, srv) +} + +func _Worker_Hello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).Hello(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/Hello", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).Hello(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_GetOrAssign_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).GetOrAssign(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/GetOrAssign", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).GetOrAssign(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_Mutate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).Mutate(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/Mutate", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).Mutate(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +func _Worker_ServeTask_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(Payload) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerServer).ServeTask(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/worker.Worker/ServeTask", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerServer).ServeTask(ctx, req.(*Payload)) + } + return interceptor(ctx, in, info, handler) +} + +var _Worker_serviceDesc = grpc.ServiceDesc{ + ServiceName: "worker.Worker", + HandlerType: (*WorkerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Hello", + Handler: _Worker_Hello_Handler, + }, + { + MethodName: "GetOrAssign", + Handler: _Worker_GetOrAssign_Handler, + }, + { + MethodName: "Mutate", + Handler: _Worker_Mutate_Handler, + }, + { + MethodName: "ServeTask", + Handler: _Worker_ServeTask_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, +} + +var fileDescriptor0 = []byte{ + // 151 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2d, 0x48, 0xac, 0xcc, + 0xc9, 0x4f, 0x4c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2b, 0xcf, 0x2f, 0xca, 0x4e, + 0x2d, 0x52, 0x92, 0xe5, 0x62, 0x0f, 0x80, 0x48, 0x08, 0x09, 0x71, 0xb1, 0xb8, 0x24, 0x96, 0x24, + 0x4a, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0xb1, 0xa4, 0x00, 0xd9, 0x46, 0xc7, 0x19, 0xb9, 0xd8, + 0xc2, 0xc1, 0x2a, 0x85, 0xb4, 0xb9, 0x58, 0x3d, 0x52, 0x73, 0x72, 0xf2, 0x85, 0xf8, 0xf5, 0x20, + 0x7a, 0xf5, 0xa0, 0x1a, 0xa5, 0xd0, 0x05, 0x94, 0x18, 0x84, 0x0c, 0xb9, 0xb8, 0xdd, 0x53, 0x4b, + 0xfc, 0x8b, 0x1c, 0x8b, 0x8b, 0x33, 0xd3, 0xf3, 0x88, 0xd2, 0xa2, 0xc3, 0xc5, 0xe6, 0x5b, 0x5a, + 0x92, 0x58, 0x92, 0x4a, 0x94, 0x6a, 0x7d, 0x2e, 0xce, 0xe0, 0xd4, 0xa2, 0xb2, 0xd4, 0x90, 0xc4, + 0xe2, 0x6c, 0x62, 0x34, 0x24, 0xb1, 0x81, 0xfd, 0x6d, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x53, + 0x6e, 0x7b, 0x2e, 0x08, 0x01, 0x00, 0x00, +} diff --git a/worker/payload.proto b/worker/payload.proto new file mode 100644 index 0000000000000000000000000000000000000000..7806379fc4e7b5b30ab92bf036ee0a5f8fd81995 --- /dev/null +++ b/worker/payload.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package worker; + +message Payload { + bytes Data = 1; +} + +service Worker { + rpc Hello (Payload) returns (Payload) {} + rpc GetOrAssign (Payload) returns (Payload) {} + rpc Mutate (Payload) returns (Payload) {} + rpc ServeTask (Payload) returns (Payload) {} +} diff --git a/worker/task.go b/worker/task.go index 358fe322371ea9e7a89620707bba201f7a08edf8..039e00539c762f8027556e1a459003399d83a361 100644 --- a/worker/task.go +++ b/worker/task.go @@ -17,12 +17,12 @@ package worker import ( - "github.com/dgraph-io/dgraph/conn" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/x" "github.com/dgryski/go-farm" "github.com/google/flatbuffers/go" + "golang.org/x/net/context" ) func ProcessTaskOverNetwork(qu []byte) (result []byte, rerr error) { @@ -52,13 +52,21 @@ func ProcessTaskOverNetwork(qu []byte) (result []byte, rerr error) { pool := pools[idx] addr := pool.Addr - query := new(conn.Query) + query := new(Payload) query.Data = qu - reply := new(conn.Reply) - if err := pool.Call("Worker.ServeTask", query, reply); err != nil { + + conn, err := pool.Get() + if err != nil { + return []byte(""), err + } + defer pool.Put(conn) + c := NewWorkerClient(conn) + reply, err := c.ServeTask(context.Background(), query) + if err != nil { glog.WithField("call", "Worker.ServeTask").Error(err) return []byte(""), err } + glog.WithField("reply_len", len(reply.Data)).WithField("addr", addr). WithField("attr", attr).Info("Got reply from server") return reply.Data, nil diff --git a/worker/worker.go b/worker/worker.go index b163cfddd64b1dc590fcc59a8e913e9eb53d75c2..549fa5c74864a8e600cd5ba1d0e2ee8860436461 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -18,11 +18,12 @@ package worker import ( "flag" - "io" "net" - "net/rpc" - "github.com/dgraph-io/dgraph/conn" + "google.golang.org/grpc" + + "golang.org/x/net/context" + "github.com/dgraph-io/dgraph/store" "github.com/dgraph-io/dgraph/task" "github.com/dgraph-io/dgraph/x" @@ -35,7 +36,7 @@ var workerPort = flag.String("workerport", ":12345", var glog = x.Log("worker") var dataStore, uidStore *store.Store -var pools []*conn.Pool +var pools []*Pool var numInstances, instanceIdx uint64 func Init(ps, uStore *store.Store, idx, numInst uint64) { @@ -45,39 +46,6 @@ func Init(ps, uStore *store.Store, idx, numInst uint64) { numInstances = numInst } -func Connect(workerList []string) { - w := new(Worker) - if err := rpc.Register(w); err != nil { - glog.Fatal(err) - } - if err := runServer(*workerPort); err != nil { - glog.Fatal(err) - } - if uint64(len(workerList)) != numInstances { - glog.WithField("len(list)", len(workerList)). - WithField("numInstances", numInstances). - Fatalf("Wrong number of instances in workerList") - } - - for _, addr := range workerList { - if len(addr) == 0 { - continue - } - pool := conn.NewPool(addr, 5) - query := new(conn.Query) - query.Data = []byte("hello") - reply := new(conn.Reply) - if err := pool.Call("Worker.Hello", query, reply); err != nil { - glog.WithField("call", "Worker.Hello").Fatal(err) - } - glog.WithField("reply", string(reply.Data)).WithField("addr", addr). - Info("Got reply from server") - pools = append(pools, pool) - } - - glog.Info("Server started. Clients connected.") -} - func NewQuery(attr string, uids []uint64) []byte { b := flatbuffers.NewBuilder(0) task.QueryStartUidsVector(b, len(uids)) @@ -95,21 +63,20 @@ func NewQuery(attr string, uids []uint64) []byte { return b.Bytes[b.Head():] } -type Worker struct { -} +type worker struct{} -func (w *Worker) Hello(query *conn.Query, reply *conn.Reply) error { - if string(query.Data) == "hello" { - reply.Data = []byte("Oh hello there!") +func (w *worker) Hello(ctx context.Context, in *Payload) (*Payload, error) { + out := new(Payload) + if string(in.Data) == "hello" { + out.Data = []byte("Oh hello there!") } else { - reply.Data = []byte("Hey stranger!") + out.Data = []byte("Hey stranger!") } - return nil -} -func (w *Worker) GetOrAssign(query *conn.Query, - reply *conn.Reply) (rerr error) { + return out, nil +} +func (w *worker) GetOrAssign(ctx context.Context, query *Payload) (*Payload, error) { uo := flatbuffers.GetUOffsetT(query.Data) xids := new(task.XidList) xids.Init(query.Data, uo) @@ -119,25 +86,31 @@ func (w *Worker) GetOrAssign(query *conn.Query, WithField("GetOrAssign", true). Fatal("We shouldn't be receiving this request.") } + + reply := new(Payload) + var rerr error reply.Data, rerr = getOrAssignUids(xids) - return + return reply, rerr } -func (w *Worker) Mutate(query *conn.Query, reply *conn.Reply) (rerr error) { +func (w *worker) Mutate(ctx context.Context, query *Payload) (*Payload, error) { m := new(Mutations) if err := m.Decode(query.Data); err != nil { - return err + return nil, err } left := new(Mutations) if err := mutate(m, left); err != nil { - return err + return nil, err } + + reply := new(Payload) + var rerr error reply.Data, rerr = left.Encode() - return + return reply, rerr } -func (w *Worker) ServeTask(query *conn.Query, reply *conn.Reply) (rerr error) { +func (w *worker) ServeTask(ctx context.Context, query *Payload) (*Payload, error) { uo := flatbuffers.GetUOffsetT(query.Data) q := new(task.Query) q.Init(query.Data, uo) @@ -145,51 +118,62 @@ func (w *Worker) ServeTask(query *conn.Query, reply *conn.Reply) (rerr error) { glog.WithField("attr", attr).WithField("num_uids", q.UidsLength()). WithField("instanceIdx", instanceIdx).Info("ServeTask") + reply := new(Payload) + var rerr error if (instanceIdx == 0 && attr == "_xid_") || farm.Fingerprint64([]byte(attr))%numInstances == instanceIdx { reply.Data, rerr = processTask(query.Data) + } else { glog.WithField("attribute", attr). WithField("instanceIdx", instanceIdx). Fatalf("Request sent to wrong server") } - return rerr + return reply, rerr } -func serveRequests(irwc io.ReadWriteCloser) { - for { - sc := &conn.ServerCodec{ - Rwc: irwc, - } - glog.Info("Serving request from serveRequests") - if err := rpc.ServeRequest(sc); err != nil { - glog.WithField("method", "serveRequests").Info(err) - break - } - } -} - -func runServer(address string) error { - ln, err := net.Listen("tcp", address) +func runServer(port string) { + ln, err := net.Listen("tcp", port) if err != nil { glog.Fatalf("While running server: %v", err) - return err + return } glog.WithField("address", ln.Addr()).Info("Worker listening") - go func() { - for { - cxn, err := ln.Accept() - if err != nil { - glog.Fatalf("listen(%q): %s\n", address, err) - return - } - glog.WithField("local", cxn.LocalAddr()). - WithField("remote", cxn.RemoteAddr()). - Debug("Worker accepted connection") - go serveRequests(cxn) + s := grpc.NewServer(grpc.CustomCodec(&PayloadCodec{})) + RegisterWorkerServer(s, &worker{}) + s.Serve(ln) +} + +func Connect(workerList []string) { + go runServer(*workerPort) + + for _, addr := range workerList { + if len(addr) == 0 { + continue + } + + pool := NewPool(addr, 5) + query := new(Payload) + query.Data = []byte("hello") + + conn, err := pool.Get() + if err != nil { + glog.WithError(err).Fatal("Unable to connect.") } - }() - return nil + + c := NewWorkerClient(conn) + reply, err := c.Hello(context.Background(), query) + if err != nil { + glog.WithError(err).Fatal("Unable to contact.") + } + _ = pool.Put(conn) + + glog.WithField("reply", string(reply.Data)).WithField("addr", addr). + Info("Got reply from server") + pools = append(pools, pool) + } + + glog.Info("Server started. Clients connected.") }