From f50dc198ffb525acd6b0045461555c73c8404457 Mon Sep 17 00:00:00 2001 From: Tony Blyler Date: Fri, 30 Jun 2017 17:40:36 -0400 Subject: [PATCH] Initial commit --- README.md | 2 ++ workergroup.go | 58 ++++++++++++++++++++++++++++++++++++ workergroup_test.go | 71 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 README.md create mode 100644 workergroup.go create mode 100644 workergroup_test.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..bc41b73 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# go-atomic +Useful custom atomic-safe [Go](https://golang.org/) structures. diff --git a/workergroup.go b/workergroup.go new file mode 100644 index 0000000..2cacb13 --- /dev/null +++ b/workergroup.go @@ -0,0 +1,58 @@ +package workergroup + +import ( + "runtime" + "sync/atomic" + "time" +) + +const waitTime = time.Nanosecond * 10 + +// WorkerGroup acts as a sync.WaitGroup but with a max worker count. +type WorkerGroup struct { + maxWorkers uint32 + workers uint32 +} + +// MaxWorkers sets max workers to `workers` or returns current setting if < 0 +func (wg *WorkerGroup) MaxWorkers(workers uint32) uint32 { + if workers == 0 { + // update max workers to GOMAXPROCS if set to 0 + atomic.CompareAndSwapUint32(&wg.maxWorkers, 0, uint32(runtime.GOMAXPROCS(0))) + + return atomic.LoadUint32(&wg.maxWorkers) + } + + atomic.StoreUint32(&wg.maxWorkers, workers) + + return workers +} + +// Add delta to the worker count +func (wg *WorkerGroup) Add(delta uint32) { + // update max workers to GOMAXPROCES if set to 0 + atomic.CompareAndSwapUint32(&wg.maxWorkers, 0, uint32(runtime.GOMAXPROCS(0))) + + var oldCount uint32 + for oldCount = atomic.LoadUint32(&wg.workers); oldCount+delta > atomic.LoadUint32(&wg.maxWorkers); oldCount = atomic.LoadUint32(&wg.workers) { + time.Sleep(waitTime) + } + + if atomic.CompareAndSwapUint32(&wg.workers, oldCount, oldCount+delta) { + return + } + + wg.Add(delta) +} + +// Done decrement the worker counter +func (wg *WorkerGroup) Done() { + atomic.AddUint32(&wg.workers, ^uint32(0)) +} + +// Wait until worker count is zero +func (wg *WorkerGroup) Wait() { + for atomic.LoadUint32(&wg.workers) != 0 { + time.Sleep(waitTime) + } +} diff --git a/workergroup_test.go b/workergroup_test.go new file mode 100644 index 0000000..1fa8796 --- /dev/null +++ b/workergroup_test.go @@ -0,0 +1,71 @@ +package workergroup + +import ( + "runtime" + "sync/atomic" + "testing" + "time" +) + +func TestWorkerGroupMaxWorkers(t *testing.T) { + wg := WorkerGroup{} + + if wg.MaxWorkers(0) != uint32(runtime.GOMAXPROCS(0)) { + t.Error("Default max workers should be GOMAXPROCS got", wg.MaxWorkers(0)) + } + + if wg.MaxWorkers(3) != 3 { + t.Error("Failed to set max workers to 3") + } +} + +func TestWorkerGroupAdd(t *testing.T) { + wg := WorkerGroup{} + + wg.MaxWorkers(3) + + wg.Add(1) + wg.Add(1) + wg.Add(1) + + waitChan := make(chan time.Time, 1) + + start := time.Now() + go func() { + wg.Add(1) + waitChan <- time.Now() + }() + + time.Sleep(time.Millisecond * 50) + + wg.Done() + stop := <-waitChan + + if time.Duration(stop.Sub(start).Nanoseconds()) < (time.Millisecond * 50) { + t.Error("wait channel should have waited at least 50ms, waited", stop.Sub(start)) + } +} + +func TestWorkerGroupWait(t *testing.T) { + wg := WorkerGroup{} + + wg.MaxWorkers(3) + + for i := 0; i < 32; i++ { + go func() { + wg.Add(1) + time.Sleep(time.Millisecond) + wg.Done() + }() + } + + // for safety + time.Sleep(time.Millisecond) + + wg.Wait() + + time.Sleep(time.Millisecond) + if atomic.LoadUint32(&wg.workers) != 0 { + t.Error("Failed to wait until workers was actually 0") + } +}