蛮荆

sync.WaitGroup Code Reading

2023-04-25

概述

sync.WaitGroup 可以等待一个并发执行的 goroutine 集合执行结束。

示例

通过一个小例子展示 sync.WaitGroup 的使用方法。

package main

import (
	"fmt"
	"strconv"
	"sync"
	"time"
)

type Task struct {
	ID   int
	Name string
}

func main() {
	tasks := make([]*Task, 0)

	// 添加 5 个任务
	for i := 1; i <= 5; i++ {
		tasks = append(tasks, &Task{
			ID:   i,
			Name: strconv.Itoa(i),
		})
	}

	var wg sync.WaitGroup

	// 开启多个 goroutine 并行执行任务
	for _, task := range tasks {
		wg.Add(1)

		go func(t *Task) {
			defer wg.Done() // 任务完成

			fmt.Printf("Task %s starting ...\n", t.Name)

			time.Sleep(300 * time.Millisecond) // 模拟任务执行耗时
		}(task)
	}

	wg.Wait() // 等待所有任务执行结束
}
$ go run main.go

# 输出如下
Task 5 starting ...
Task 1 starting ...
Task 4 starting ...
Task 3 starting ...
Task 2 starting ...

从输出的结果中可以看到,虽然任务执行完成顺序和添加顺序并不一致,但是最终 5 个任务全部执行完成。

内部实现

我们来探究一下 sync.WaitGroup 的内部实现,文件路径为 $GOROOT/src/sync/waitgroup.go,笔者的 Go 版本为 go1.19 linux/amd64

WaitGroup 对象

WaitGroup 对象表示并发 goroutine 集合的控制器,具体的使用方法为:

  • main goroutine 调用 Add 方法设置对象需要等待的 goroutine 数量
  • main goroutine 调用 Wait 方法 阻塞等待 goroutine 执行结束
  • 其他 goroutine 在执行结束时调用 Done 方法通知 main goroutine

根据 Go 内存模型的约束,goroutine 调用 Done 方法时,必须在对应的 Wait 方法之前调用,否则对应的 Wait 方法将永远阻塞。

// WaitGroup 一旦使用后,就不能再复制
type WaitGroup struct {
	noCopy noCopy // 保证编译期间不会发生复制
	
	state1 uint64
	state2 uint32
}

两个字段表示的三个变量

三个语义变量

state1state2 两个字段其实表示了三个语义变量,分别为:

  • counter: (计数器) 当前执行的 goroutine (调用了 Add 方法) 数量
  • waiter: (等待者) 当前等待的 goroutine (调用了 Wait 方法) 数量
  • sema: (信号量) 用于休眠或唤醒 goroutine

为什么不直接设置三个变量呢?

因为 counter 和 waiter 计数器根据内存对齐情况放进一个 64 位整数里面,这是标准库做的一个优化,将两个计数器放进一个变量,这样就可以在不加锁的情况下,支持并发场景下的原子操作了,极大地提高了性能

state 方法

state 方法返回两个指针变量,statep 变量表示 counter 和 waiter 计数器,semap 变量表示信号量。

stete 方法会根据 state1 字段的内存对齐位数,在必要时动态 “交换” 三个语义变量的顺序

64 位对齐

在 32 位架构中,WaitGroup 对象初始化时分配的内存地址是随机的,state1 字段起始的位置不一定 64 位对齐,所以需要和 state2 字段拼接起来,实现内存连续的情况下保证 64 位对齐。

非 64 位对齐

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	// 判断 state1 字段是否按照 64 位对齐
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// 如果 state1 字段是 64 位对齐,直接返回
		return &wg.state1, &wg.state2
	} else {
		// 如果 state1 是 32 位而非 64 位对齐
		// 这意味着 (&state1)+4 是 64 位对齐 (state1 字段 + 4, 正好是 state2 字段)
		// (&state1)+4 等于跨了两个字段,所以是 64 位对齐 (两个字段的内存是连续的)
		// 最后把两个字段地址进行连接,在连接的基础上实现地址交换
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}

Add 方法

Add 方法增加 delta 个计数,内部会添加到对应的 counter 计数器上,如果 counter 变为 0,所有阻塞在 Wait 方法上的 goroutine 都会立即完成并被释放。

具体的调用规则如下:

  1. 当 counter == 0 并且 delta > 0 时,必须在 Wait 方法之前调用 Add 方法
  2. 当 counter > 0 并且 delta < 0 时,可以在任何时候调用 Add 方法
  3. 一般情况下,Add 方法应该在创建 goroutine 时或其他阻塞场景发生前调用
  4. 如果 WaitGroup 要重复使用,应该在所有 Wait 方法返回之后再继续调用 Add 方法
func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state() // 调用 state() 取出计数器和信号量
	
    ...
	
	state := atomic.AddUint64(statep, uint64(delta)<<32) // 增加计数器的值 
	v := int32(state >> 32) // 获取计数器的值 (高位字节)
	w := uint32(state)  // 获取等待者的值 (低位字节)
	
	...
	
	if v < 0 {
        // 计数器不能为负数 (出现了 BUG)
		panic("sync: negative WaitGroup counter")
	}
	
	// 等待者不等于 0, 说明已经有 goroutine 调用了 Wait 方法
	// 此时不允许再调用 Add 方法了 (参考规则 4)
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	
	if v > 0 || w == 0 {
        // 如果计数器大于 0 或者没有等待者,直接返回
		return
	}
	
	// 当等待者大于 0 并且计数器等于 0 (所有 goroutine 都调用了 Done 方法表示其结束执行)
	// 重置计数器和等待者为 0
	*statep = 0
	// 唤醒所有等待者 (逐个阻塞调用)
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

Done 方法

Done 方法简单地封装了一下 Add 方法 (等于调用 Add(-1)),提供了一个可读性更高的操作原语。

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait 方法

Wait 方法会进入阻塞,直到计数器的值等于 0。

func (wg *WaitGroup) Wait() {
	statep, semap := wg.state() // 调用 state() 取出计数器和信号量

	...
	    
	for {
		state := atomic.LoadUint64(statep)
        v := int32(state >> 32) // 获取计数器的值 (高位字节)
        w := uint32(state)  // 获取等待者的值 (低位字节)
		if v == 0 {
			// 计数器等于 0,直接返回
			return
		}
		
		// 计数器不等于 0,说明存在并发
		// 增加等待者的值
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
            ...
			
			// 休眠当前 goroutine 等待唤醒
			runtime_Semacquire(semap)
			if *statep != 0 {
				// 等待者不等于 0, 说明 WaitGroup 对象被重复使用了 (参考规则 4)
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			
            ...
			
			return
		}
	}
}

noCopy 对象

noCopy 对象可以添加到具体的结构体中,实现 “首次使用之后,无法被复制” 的功能 (由编译器实现)。

noCopy.Lock 方法是一个空操作,由 go vet 工具链中的 -copylocks checker 参数指令使用。

type noCopy struct{}

func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

小结

sync.WaitGroup 的代码实现中,有两个非常重要的优化技巧值得我们学习:

  • 通过将多个变量放入一个变量,实现无加锁的原子操作
  • state 方法不仅提供了标准的 Go 内存对齐 检测方法, 同时 通过将连续的地址变换为数组,实现内存交换

转载申请

本作品采用 知识共享署名 4.0 国际许可协议 进行许可,转载时请注明原文链接,图片在使用时请保留全部内容,商业转载请联系作者获得授权。