diff --git a/concurrent/map.go b/concurrent/map.go index 41612fe93ebaa52083bab491e0dcdc3e6da23dd2..05844df911e7551737ea4cd908c910850b35cd28 100644 --- a/concurrent/map.go +++ b/concurrent/map.go @@ -18,6 +18,7 @@ package concurrent import ( "log" + "math/rand" "sync/atomic" "unsafe" ) @@ -31,14 +32,21 @@ type bucket struct { elems [8]kv } +const ( + MUTABLE = iota + IMMUTABLE +) + type container struct { - active int32 - sz uint64 - list []*bucket + status int32 + sz uint64 + list []*bucket + numElems uint32 } type Map struct { - cs [2]container + cs [2]unsafe.Pointer + size uint32 } func powOf2(sz int) bool { @@ -46,7 +54,7 @@ func powOf2(sz int) bool { } func initContainer(cs *container, sz uint64) { - cs.active = 1 + cs.status = MUTABLE cs.sz = sz cs.list = make([]*bucket, sz) for i := range cs.list { @@ -59,47 +67,98 @@ func NewMap(sz int) *Map { log.Fatal("Map can only be created for a power of 2.") } + c := new(container) + initContainer(c, uint64(sz)) + m := new(Map) - initContainer(&m.cs[0], uint64(sz)) + m.cs[MUTABLE] = unsafe.Pointer(c) + m.cs[IMMUTABLE] = nil return m } -func (m *Map) Get(k uint64) unsafe.Pointer { - for _, c := range m.cs { - if atomic.LoadInt32(&c.active) == 0 { - continue +func (c *container) get(k uint64) unsafe.Pointer { + bi := k & (c.sz - 1) + b := c.list[bi] + for i := range b.elems { + e := &b.elems[i] + if ek := atomic.LoadUint64(&e.k); ek == k { + return e.v } - bi := k & (c.sz - 1) - b := c.list[bi] - for i := range b.elems { - e := &b.elems[i] - ek := atomic.LoadUint64(&e.k) - if ek == k { - return e.v + } + return nil +} + +func (c *container) getOrInsert(k uint64, v unsafe.Pointer) unsafe.Pointer { + bi := k & (c.sz - 1) + b := c.list[bi] + for i := range b.elems { + e := &b.elems[i] + // Once allocated a valid key, it would never change. So, first check if + // it's allocated. If not, then allocate it. If can't, or not allocated, + // then check if it's k. If it is, then replace value. Otherwise continue. + // This sequence could be problematic, if this happens: + // Main thread runs Step 1. Check + if atomic.CompareAndSwapUint64(&e.k, 0, k) { // Step 1. + atomic.AddUint32(&c.numElems, 1) + if atomic.CompareAndSwapPointer(&e.v, nil, v) { + return v } + return atomic.LoadPointer(&e.v) + } + + if atomic.LoadUint64(&e.k) == k { + // Swap if previous pointer is nil. + if atomic.CompareAndSwapPointer(&e.v, nil, v) { + return v + } + return atomic.LoadPointer(&e.v) + } + } + return nil +} + +func (m *Map) GetOrInsert(k uint64, v unsafe.Pointer) unsafe.Pointer { + if v == nil { + log.Fatal("GetOrInsert doesn't allow setting nil pointers.") + return nil + } + + // Check immutable first. + cval := atomic.LoadPointer(&m.cs[IMMUTABLE]) + if cval != nil { + c := (*container)(cval) + if pv := c.get(k); pv != nil { + return pv } } + + // Okay, deal with mutable container now. + cval = atomic.LoadPointer(&m.cs[MUTABLE]) + if cval == nil { + log.Fatal("This is disruptive in a bad way.") + } + c := (*container)(cval) + if pv := c.getOrInsert(k, v); pv != nil { + return pv + } + + // We still couldn't insert the key. Time to grow. + // TODO: Handle this case. return nil } -func (m *Map) Put(k uint64, v unsafe.Pointer) bool { +func (m *Map) SetNilIfPresent(k uint64) bool { for _, c := range m.cs { - if atomic.LoadInt32(&c.active) == 0 { + if atomic.LoadInt32(&c.status) == 0 { continue } bi := k & (c.sz - 1) b := c.list[bi] for i := range b.elems { e := &b.elems[i] - // Once allocated a valid key, it would never change. So, first check if - // it's allocated. If not, then allocate it. If can't, or not allocated, - // then check if it's k. If it is, then replace value. Otherwise continue. - if atomic.CompareAndSwapUint64(&e.k, 0, k) { - atomic.StorePointer(&e.v, v) - return true - } if atomic.LoadUint64(&e.k) == k { - atomic.StorePointer(&e.v, v) + // Set to nil. + atomic.StorePointer(&e.v, nil) return true } } @@ -107,24 +166,30 @@ func (m *Map) Put(k uint64, v unsafe.Pointer) bool { return false } -/* func (m *Map) StreamUntilCap(ch chan uint64) { - for _, c := range m.cs { - if atomic.LoadInt32(&c.active) == 0 { - continue + for { + ci := rand.Intn(2) + c := m.cs[ci] + if atomic.LoadInt32(&c.status) == 0 { + ci += 1 + c = m.cs[ci%2] // use the other. } - for { - bi := rand.Intn(int(c.sz)) - for len(ch) < cap(ch) { + bi := rand.Intn(int(c.sz)) + + for _, e := range c.list[bi].elems { + if len(ch) >= cap(ch) { + return + } + if k := atomic.LoadUint64(&e.k); k > 0 { + ch <- k } } } } -*/ func (m *Map) StreamAll(ch chan uint64) { for _, c := range m.cs { - if atomic.LoadInt32(&c.active) == 0 { + if atomic.LoadInt32(&c.status) == 0 { continue } for i := 0; i < int(c.sz); i++ {