Files

636 lines
15 KiB
Go
Raw Permalink Normal View History

2020-12-09 12:43:56 -08:00
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package atomicptrmap
import (
"context"
"fmt"
"math/rand"
"reflect"
"runtime"
"testing"
"time"
"gvisor.dev/gvisor/pkg/sync"
)
// TestConsistencyWithGoMap applies one pseudo-random sequence of operations
// to both a testAtomicPtrMap and a built-in Go map that serves as a reference
// model, and fails as soon as their observable results disagree. In the
// reference model a nil value is represented by the key being absent.
func TestConsistencyWithGoMap(t *testing.T) {
	const (
		maxKey = 16
		numOps = 100000
	)

	// Candidate values. Index 0 is intentionally left nil so that storing nil
	// is exercised as well.
	var vals [4]*testValue
	for i := range vals[1:] {
		vals[i+1] = new(testValue)
	}
	randKey := func() int64 { return rand.Int63n(maxKey) }
	randVal := func() *testValue { return vals[rand.Intn(len(vals))] }

	var (
		ref = make(map[int64]*testValue)
		apm testAtomicPtrMap
	)
	// setRef records val for key in the reference map, removing the key
	// instead when val is nil.
	setRef := func(key int64, val *testValue) {
		if val == nil {
			delete(ref, key)
			return
		}
		ref[key] = val
	}
	// expectPtr fails the test immediately if got != want.
	expectPtr := func(got, want *testValue) {
		t.Helper()
		if got != want {
			t.Fatalf("got %p, wanted %p", got, want)
		}
	}

	for n := 0; n < numOps; n++ {
		// Weights out of 10: Load 2, Swap 2, CompareAndSwap 4 (it has the
		// most distinct cases), Range 1 and RangeRepeatable 1 (both are
		// comparatively expensive).
		switch op := rand.Intn(10); {
		case op <= 1: // Load
			key := randKey()
			got := apm.Load(key)
			t.Logf("Load(%d) = %p", key, got)
			expectPtr(got, ref[key])
		case op <= 3: // Swap
			key := randKey()
			val := randVal()
			want := ref[key]
			setRef(key, val)
			got := apm.Swap(key, val)
			t.Logf("Swap(%d, %p) = %p", key, val, got)
			expectPtr(got, want)
		case op <= 7: // CompareAndSwap
			key := randKey()
			oldVal := randVal()
			newVal := randVal()
			want := ref[key]
			if want == oldVal {
				setRef(key, newVal)
			}
			got := apm.CompareAndSwap(key, oldVal, newVal)
			t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got)
			expectPtr(got, want)
		case op == 8: // Range
			got := make(map[int64]*testValue)
			dup, haveDup := int64(0), false
			apm.Range(func(key int64, val *testValue) bool {
				// Remember only the first key reported more than once.
				if _, seen := got[key]; seen && !haveDup {
					dup, haveDup = key, true
				}
				got[key] = val
				return true
			})
			t.Logf("Range() = %v", got)
			if !reflect.DeepEqual(got, ref) {
				t.Fatalf("got %v, wanted %v", got, ref)
			}
			if haveDup {
				t.Fatalf("got duplicate key %d", dup)
			}
		default: // RangeRepeatable (op == 9)
			got := make(map[int64]*testValue)
			apm.RangeRepeatable(func(key int64, val *testValue) bool {
				got[key] = val
				return true
			})
			t.Logf("RangeRepeatable() = %v", got)
			if !reflect.DeepEqual(got, ref) {
				t.Fatalf("got %v, wanted %v", got, ref)
			}
		}
	}
}
// TestConcurrentHeterogeneous runs several kinds of goroutines against a
// single testAtomicPtrMap for testDuration and checks that each observation
// they make is consistent:
//
//   - Goroutines that own a private key predict exactly what every
//     Load/Swap/CompareAndSwap on that key must return.
//   - Goroutines sharing a small key set check only that any non-nil value
//     they observe belongs to that key's value set.
//   - One goroutine loads keys that are never stored and expects nil.
//   - One goroutine calls RangeRepeatable, and the main goroutine calls
//     Range. Both check that every reported (key, value) pair is possible.
//     Range must also never report the same key twice in one pass.
//
// Every checker stops at its first reported error, and the main goroutine
// stops early once the test has failed, so a single bug does not produce an
// unbounded stream of errors.
func TestConcurrentHeterogeneous(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	var (
		apm testAtomicPtrMap
		wg  sync.WaitGroup
	)
	// Stop all background goroutines and wait for them before returning, so
	// none of them calls t.Errorf after the test has completed.
	defer func() {
		cancel()
		wg.Wait()
	}()

	// possibleKeyValuePairs maps each key that may ever be stored to the set
	// of values that may be stored for it. It is fully built before any
	// goroutine starts and is only read afterwards.
	possibleKeyValuePairs := make(map[int64]map[*testValue]struct{})
	addKeyValuePair := func(key int64, val *testValue) {
		values := possibleKeyValuePairs[key]
		if values == nil {
			values = make(map[*testValue]struct{})
			possibleKeyValuePairs[key] = values
		}
		values[val] = struct{}{}
	}

	const numValuesPerKey = 4

	// These goroutines use keys not used by any other goroutine, so each one
	// can track the exact current value of its key in `stored`.
	const numPrivateKeys = 3
	for i := 0; i < numPrivateKeys; i++ {
		key := int64(i)
		var vals [numValuesPerKey]*testValue
		for j := 1; /* leave vals[0] nil */ j < len(vals); j++ {
			val := new(testValue)
			vals[j] = val
			addKeyValuePair(key, val)
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			r := rand.New(rand.NewSource(rand.Int63()))
			var stored *testValue
			for ctx.Err() == nil {
				switch r.Intn(4) {
				case 0:
					got := apm.Load(key)
					if got != stored {
						t.Errorf("Load(%d): got %p, wanted %p", key, got, stored)
						return
					}
				case 1:
					val := vals[r.Intn(len(vals))]
					want := stored
					stored = val
					got := apm.Swap(key, val)
					if got != want {
						t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want)
						return
					}
				case 2, 3:
					oldVal := vals[r.Intn(len(vals))]
					newVal := vals[r.Intn(len(vals))]
					want := stored
					if stored == oldVal {
						stored = newVal
					}
					got := apm.CompareAndSwap(key, oldVal, newVal)
					if got != want {
						t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want)
						return
					}
				}
			}
		}()
	}

	// These goroutines share a small set of keys. Because other goroutines
	// race on the same keys, only membership in the key's value set can be
	// checked.
	const numSharedKeys = 2
	var (
		sharedKeys      [numSharedKeys]int64
		sharedValues    = make(map[int64][]*testValue)
		sharedValuesSet = make(map[int64]map[*testValue]struct{})
	)
	for i := range sharedKeys {
		key := int64(numPrivateKeys + i)
		sharedKeys[i] = key
		vals := make([]*testValue, numValuesPerKey)
		valsSet := make(map[*testValue]struct{})
		for j := range vals {
			val := new(testValue)
			vals[j] = val
			valsSet[val] = struct{}{}
			addKeyValuePair(key, val)
		}
		sharedValues[key] = vals
		sharedValuesSet[key] = valsSet
	}
	randSharedValue := func(r *rand.Rand, key int64) *testValue {
		vals := sharedValues[key]
		return vals[r.Intn(len(vals))]
	}
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			r := rand.New(rand.NewSource(rand.Int63()))
			for ctx.Err() == nil {
				keyIndex := r.Intn(len(sharedKeys))
				key := sharedKeys[keyIndex]
				var (
					op  string
					got *testValue
				)
				switch r.Intn(4) {
				case 0:
					op = "Load"
					got = apm.Load(key)
				case 1:
					op = "Swap"
					got = apm.Swap(key, randSharedValue(r, key))
				case 2, 3:
					op = "CompareAndSwap"
					got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key))
				}
				if got != nil {
					valsSet := sharedValuesSet[key]
					if _, ok := valsSet[got]; !ok {
						t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet)
						return
					}
				}
			}
		}()
	}

	// This goroutine repeatedly searches for unused keys. All keys stored
	// above are non-negative, so a negative key must always load as nil.
	wg.Add(1)
	go func() {
		defer wg.Done()
		r := rand.New(rand.NewSource(rand.Int63()))
		for ctx.Err() == nil {
			key := -1 - r.Int63()
			if got := apm.Load(key); got != nil {
				t.Errorf("Load(%d): got %p, wanted nil", key, got)
				// Stop after the first failure, like every other checker;
				// otherwise this loop could report errors without bound.
				return
			}
		}
	}()

	// This goroutine repeatedly calls RangeRepeatable() and checks that each
	// key corresponds to an expected value.
	wg.Add(1)
	go func() {
		defer wg.Done()
		abort := false
		for !abort && ctx.Err() == nil {
			apm.RangeRepeatable(func(key int64, val *testValue) bool {
				values, ok := possibleKeyValuePairs[key]
				if !ok {
					t.Errorf("RangeRepeatable: got invalid key %d", key)
					abort = true
					return false
				}
				if _, ok := values[val]; !ok {
					t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values)
					abort = true
					return false
				}
				return true
			})
		}
	}()

	// Finally, the main goroutine spins for the length of the test calling
	// Range() and checking that each key it observes is unique and
	// corresponds to an expected value. It stops early if any goroutine has
	// already failed the test, since further spinning adds nothing.
	seenKeys := make(map[int64]struct{})
	const testDuration = 5 * time.Second
	end := time.Now().Add(testDuration)
	abort := false
	for time.Now().Before(end) && !t.Failed() {
		apm.Range(func(key int64, val *testValue) bool {
			values, ok := possibleKeyValuePairs[key]
			if !ok {
				t.Errorf("Range: got invalid key %d", key)
				abort = true
				return false
			}
			if _, ok := values[val]; !ok {
				t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values)
				abort = true
				return false
			}
			if _, ok := seenKeys[key]; ok {
				t.Errorf("Range: got duplicate key %d", key)
				abort = true
				return false
			}
			seenKeys[key] = struct{}{}
			return true
		})
		if abort {
			break
		}
		// Reset for the next pass; duplicates are checked per Range call.
		for k := range seenKeys {
			delete(seenKeys, k)
		}
	}
}
// benchmarkableMap is the common surface that the benchmarks below drive, so
// that each map implementation can be compared on identical workloads.
type benchmarkableMap interface {
	// Load returns the value stored for key, or nil if there is none.
	Load(key int64) *testValue
	// LoadOrStore returns the existing value for key and true if one is
	// present; otherwise it stores val and returns val and false.
	LoadOrStore(key int64, val *testValue) (*testValue, bool)
	// Store sets the value for key to val.
	Store(key int64, val *testValue)
	// Delete removes any value stored for key.
	Delete(key int64)
}
// rwMutexMap is a benchmarkableMap backed by a built-in Go map guarded by a
// sync.RWMutex. Load takes the read lock, and every other method takes the
// write lock. The zero value is ready to use: the underlying map is allocated
// lazily on the first Store or LoadOrStore.
type rwMutexMap struct {
	mu sync.RWMutex
	m  map[int64]*testValue
}

// Load returns the value stored for key, or nil if key is absent.
func (rm *rwMutexMap) Load(key int64) *testValue {
	rm.mu.RLock()
	defer rm.mu.RUnlock()
	return rm.m[key]
}

// initLocked allocates the backing map if it does not exist yet. The caller
// must hold rm.mu for writing.
func (rm *rwMutexMap) initLocked() {
	if rm.m == nil {
		rm.m = make(map[int64]*testValue)
	}
}

// Store sets the value for key to val. A nil val is stored as an entry, not
// treated as a deletion.
func (rm *rwMutexMap) Store(key int64, val *testValue) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.initLocked()
	rm.m[key] = val
}

// LoadOrStore returns the existing entry for key and true if key is present,
// even if the stored value is nil. Otherwise it stores val and returns val and
// false. It always takes the write lock.
func (rm *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	rm.initLocked()
	existing, found := rm.m[key]
	if !found {
		rm.m[key] = val
		return val, false
	}
	return existing, true
}

// Delete removes key if it is present. It is safe to call on the zero value.
func (rm *rwMutexMap) Delete(key int64) {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	delete(rm.m, key)
}
// syncMap is a benchmarkableMap that adapts a sync.Map, storing int64 keys
// and *testValue values in its untyped interface slots.
type syncMap struct {
	m sync.Map
}

// Load returns the value stored for key, or nil if key is absent.
func (sm *syncMap) Load(key int64) *testValue {
	if v, ok := sm.m.Load(key); ok {
		return v.(*testValue)
	}
	return nil
}

// Store sets the value for key to val.
func (sm *syncMap) Store(key int64, val *testValue) {
	sm.m.Store(key, val)
}

// LoadOrStore forwards to sync.Map.LoadOrStore and converts the resulting
// value back to *testValue.
func (sm *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
	v, loaded := sm.m.LoadOrStore(key, val)
	return v.(*testValue), loaded
}

// Delete removes any value stored for key.
func (sm *syncMap) Delete(key int64) {
	sm.m.Delete(key)
}
// benchmarkableAtomicPtrMap adapts testAtomicPtrMap to benchmarkableMap. The
// wrapped map uses a nil value to mean "absent", so Delete stores nil and
// LoadOrStore is built on CompareAndSwap against nil.
type benchmarkableAtomicPtrMap struct {
	m testAtomicPtrMap
}

// Load returns the value stored for key, or nil if there is none.
func (am *benchmarkableAtomicPtrMap) Load(key int64) *testValue {
	return am.m.Load(key)
}

// Store sets the value for key to val.
func (am *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) {
	am.m.Store(key, val)
}

// LoadOrStore installs val only if key currently maps to nil. CompareAndSwap
// returns the previous value: nil means val was stored, and anything else is
// the existing value that was kept.
func (am *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
	prev := am.m.CompareAndSwap(key, nil, val)
	if prev == nil {
		return val, false
	}
	return prev, true
}

// Delete clears key by storing nil.
func (am *benchmarkableAtomicPtrMap) Delete(key int64) {
	am.m.Store(key, nil)
}
// benchmarkableAtomicPtrMapSharded adapts testAtomicPtrMapSharded to
// benchmarkableMap. Like the unsharded adapter, it treats a nil value as
// "absent": Delete stores nil and LoadOrStore uses CompareAndSwap against nil.
type benchmarkableAtomicPtrMapSharded struct {
	m testAtomicPtrMapSharded
}

// Load returns the value stored for key, or nil if there is none.
func (sm *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue {
	return sm.m.Load(key)
}

// Store sets the value for key to val.
func (sm *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) {
	sm.m.Store(key, val)
}

// LoadOrStore installs val only if key currently maps to nil. A nil previous
// value from CompareAndSwap means val was stored; otherwise the existing value
// is returned unchanged.
func (sm *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) {
	prev := sm.m.CompareAndSwap(key, nil, val)
	if prev == nil {
		return val, false
	}
	return prev, true
}

// Delete clears key by storing nil.
func (sm *benchmarkableAtomicPtrMapSharded) Delete(key int64) {
	sm.m.Store(key, nil)
}
// mapImpls enumerates every benchmarkableMap implementation the benchmarks
// compare. Each entry has a sub-benchmark name and a constructor that returns
// a fresh, empty map.
var mapImpls = [...]struct {
	name string
	ctor func() benchmarkableMap
}{
	{"RWMutexMap", func() benchmarkableMap { return &rwMutexMap{} }},
	{"SyncMap", func() benchmarkableMap { return &syncMap{} }},
	{"AtomicPtrMap", func() benchmarkableMap { return &benchmarkableAtomicPtrMap{} }},
	{"AtomicPtrMapSharded", func() benchmarkableMap { return &benchmarkableAtomicPtrMapSharded{} }},
}
// benchmarkStoreDelete measures inserting b.N distinct keys (0 through b.N-1)
// into a fresh map with Store and then removing all of them with Delete. Both
// phases are timed.
func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
	m, val := mapCtor(), new(testValue)
	n := int64(b.N)
	for key := int64(0); key < n; key++ {
		m.Store(key, val)
	}
	for key := int64(0); key < n; key++ {
		m.Delete(key)
	}
}
// BenchmarkStoreDelete runs benchmarkStoreDelete as one sub-benchmark per
// entry in mapImpls.
func BenchmarkStoreDelete(b *testing.B) {
	for i := range mapImpls {
		impl := &mapImpls[i]
		b.Run(impl.name, func(b *testing.B) {
			benchmarkStoreDelete(b, impl.ctor)
		})
	}
}
// benchmarkLoadOrStoreDelete measures inserting b.N distinct keys (0 through
// b.N-1) into a fresh map with LoadOrStore, where every call inserts, and then
// removing them all with Delete. Both phases are timed.
func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) {
	m, val := mapCtor(), new(testValue)
	n := int64(b.N)
	for key := int64(0); key < n; key++ {
		m.LoadOrStore(key, val)
	}
	for key := int64(0); key < n; key++ {
		m.Delete(key)
	}
}
// BenchmarkLoadOrStoreDelete runs benchmarkLoadOrStoreDelete as one
// sub-benchmark per entry in mapImpls.
func BenchmarkLoadOrStoreDelete(b *testing.B) {
	for i := range mapImpls {
		impl := &mapImpls[i]
		b.Run(impl.name, func(b *testing.B) {
			benchmarkLoadOrStoreDelete(b, impl.ctor)
		})
	}
}
// benchmarkLookupPositive measures Load on keys that are all present. The map
// is populated with keys 0 through b.N-1 before the timer is reset, so only
// the lookups are timed.
func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) {
	m, val := mapCtor(), new(testValue)
	n := int64(b.N)
	for key := int64(0); key < n; key++ {
		m.Store(key, val)
	}
	b.ResetTimer()
	for key := int64(0); key < n; key++ {
		m.Load(key)
	}
}
// BenchmarkLookupPositive runs benchmarkLookupPositive as one sub-benchmark
// per entry in mapImpls.
func BenchmarkLookupPositive(b *testing.B) {
	for i := range mapImpls {
		impl := &mapImpls[i]
		b.Run(impl.name, func(b *testing.B) {
			benchmarkLookupPositive(b, impl.ctor)
		})
	}
}
// benchmarkLookupNegative measures Load on keys that are all absent. The map
// is populated with keys 0 through b.N-1 before the timer is reset. It then
// looks up -1, -2, ..., -b.N, none of which were stored.
func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) {
	m, val := mapCtor(), new(testValue)
	n := int64(b.N)
	for key := int64(0); key < n; key++ {
		m.Store(key, val)
	}
	b.ResetTimer()
	for i := int64(0); i < n; i++ {
		// ^i == -1 - i in two's complement, so this maps i >= 0 to a
		// negative key that never collides with the stored ones.
		m.Load(^i)
	}
}
// BenchmarkLookupNegative runs benchmarkLookupNegative as one sub-benchmark
// per entry in mapImpls.
func BenchmarkLookupNegative(b *testing.B) {
	for i := range mapImpls {
		impl := &mapImpls[i]
		b.Run(impl.name, func(b *testing.B) {
			benchmarkLookupNegative(b, impl.ctor)
		})
	}
}
// benchmarkConcurrentOptions configures the workload of benchmarkConcurrent.
type benchmarkConcurrentOptions struct {
	// loadsPerMutationPair is how many Load calls a worker makes between
	// inserting a key (LoadOrStore) and deleting it again (Delete).
	loadsPerMutationPair int

	// changeKeys selects whether each worker uses a new key on every
	// iteration (true) or reuses one fixed key for the whole run (false).
	changeKeys bool
}
// benchmarkConcurrent starts one worker per GOMAXPROCS, all operating on a
// single shared map. Each worker repeatedly inserts a key with LoadOrStore,
// loads it opts.loadsPerMutationPair times, and deletes it. It runs enough
// iterations to perform at least b.N operations of its own. Map setup and
// goroutine creation happen before the timer is reset, and all workers are
// released together afterwards.
func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) {
	m := mapCtor()
	val := &testValue{}

	// Pre-fill the map with 10000 entries that the workers never touch, so
	// that the elements in use are spread throughout memory. The keys -1
	// through -10000 cannot collide with worker keys, which are all
	// non-negative.
	for k := int64(1); k <= 10000; k++ {
		m.Store(-k, val)
	}

	// Each iteration is 1 LoadOrStore + loadsPerMutationPair Loads + 1
	// Delete. Round up so each worker performs at least b.N operations.
	opsPerIter := opts.loadsPerMutationPair + 2
	iters := (b.N + opsPerIter - 1) / opsPerIter

	var start, done sync.WaitGroup
	start.Add(1)
	procs := runtime.GOMAXPROCS(0)
	done.Add(procs)
	for w := 0; w < procs; w++ {
		go func(workerID int) {
			defer done.Done()
			start.Wait() // hold until the timer has been reset
			for i := 0; i < iters; i++ {
				key := int64(workerID)
				if opts.changeKeys {
					// Disjoint key range per worker: [workerID*iters, (workerID+1)*iters).
					key = int64(workerID*iters + i)
				}
				m.LoadOrStore(key, val)
				for j := opts.loadsPerMutationPair; j > 0; j-- {
					m.Load(key)
				}
				m.Delete(key)
			}
		}(w)
	}

	b.ResetTimer()
	start.Done()
	done.Wait()
}
func BenchmarkConcurrent(b *testing.B) {
changeKeysChoices := [...]struct {
name string
val bool
}{
{"FixedKeys", false},
{"ChangingKeys", true},
}
writePcts := [...]struct {
name string
loadsPerMutationPair int
}{
{"1PercentWrites", 198},
{"10PercentWrites", 18},
{"50PercentWrites", 2},
}
for _, changeKeys := range changeKeysChoices {
for _, writePct := range writePcts {
for _, mapImpl := range mapImpls {
name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name)
b.Run(name, func(b *testing.B) {
benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{
loadsPerMutationPair: writePct.loadsPerMutationPair,
changeKeys: changeKeys.val,
})
})
}
}
}
}