Go · #go#context#cancellation

Go Context最佳实践与源码分析

2023.08.09 Go 9 min 3.5k
// 目录 · contents

引言

context.Context是Go标准库中最重要的接口之一。它提供了跨API边界和goroutine之间传递截止时间、取消信号和请求范围值的标准方式。自Go 1.7正式进入标准库以来,context已经成为Go服务端编程的基石。

然而,context的误用也是Go开发中最常见的问题之一。本文将从接口定义出发,深入分析context的源码实现,结合实际案例探讨最佳实践与需要避免的反模式。

Context接口定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// context/context.go
type Context interface {
// Deadline 返回context被取消的时间点
// ok==false 表示没有设置deadline
Deadline() (deadline time.Time, ok bool)

// Done 返回一个channel,当context被取消时关闭
// 如果context永远不会被取消,返回nil
Done() <-chan struct{}

// Err 返回context被取消的原因
// Done channel未关闭时返回nil
// 可能的值:context.Canceled 或 context.DeadlineExceeded
Err() error

// Value 返回与key关联的值
Value(key interface{}) interface{}
}
graph TB
    subgraph "Context 接口"
        I["Context"]
        I --> D["Deadline() - 获取截止时间"]
        I --> DN["Done() - 取消通知channel"]
        I --> E["Err() - 取消原因"]
        I --> V["Value() - 携带的值"]
    end

Context树形结构

context通过WithCancelWithTimeoutWithDeadlineWithValue等函数派生出子context,形成树形结构。父context取消时,所有子context也会被取消。

graph TB
    BG["context.Background()"] --> C1["WithCancel"]
    BG --> C2["WithTimeout(5s)"]
    C1 --> C3["WithValue(userID)"]
    C1 --> C4["WithTimeout(3s)"]
    C2 --> C5["WithCancel"]
    C2 --> C6["WithValue(traceID)"]

    style BG fill:#9f9
    style C1 fill:#ff9
    style C2 fill:#ff9
    style C3 fill:#9ff
    style C4 fill:#ff9
    style C5 fill:#ff9
    style C6 fill:#9ff

当C1被取消时,C3和C4也会被取消。当C2超时时,C5和C6也会被取消。

源码分析

emptyCtx:根context

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// context.Background() 和 context.TODO() 返回的类型
type emptyCtx struct{}

func (emptyCtx) Deadline() (deadline time.Time, ok bool) { return }
func (emptyCtx) Done() <-chan struct{} { return nil }
func (emptyCtx) Err() error { return nil }
func (emptyCtx) Value(key interface{}) interface{} { return nil }

var (
background = new(emptyCtx)
todo = new(emptyCtx)
)

func Background() Context { return background }
func TODO() Context { return todo }

cancelCtx:可取消context

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// WithCancel 返回的context类型
type cancelCtx struct {
Context // 嵌入父context

mu sync.Mutex // 保护以下字段
done atomic.Value // chan struct{}, lazily created
children map[canceler]struct{} // 子context集合
err error // 取消原因
cause error // 取消的根因 (Go 1.20+)
}

// cancel 方法的实现
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // 已经被取消
}
c.err = err
c.cause = cause

// 关闭done channel,通知所有等待者
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
} else {
close(d)
}

// 递归取消所有子context
for child := range c.children {
child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()

if removeFromParent {
removeChild(c.Context, c)
}
}
sequenceDiagram
    participant P as 父Context
    participant C as cancelCtx
    participant CH as children
    participant D as done channel

    Note over P,D: WithCancel 创建过程
    P->>C: 创建 cancelCtx{parent}
    C->>P: 注册为父的child

    Note over P,D: cancel() 调用过程
    C->>C: 设置 err
    C->>D: close(done)
    loop 遍历children
        C->>CH: child.cancel()
    end
    C->>P: 从父中移除自己

timerCtx:带超时的context

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
type timerCtx struct {
cancelCtx // 嵌入cancelCtx
timer *time.Timer // 定时器
deadline time.Time // 截止时间
}

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
return WithDeadline(parent, time.Now().Add(timeout))
}

func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
// 如果父context的deadline更早,直接用WithCancel
if cur, ok := parent.Deadline(); ok && cur.Before(d) {
return WithCancel(parent)
}

c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: d,
}
propagateCancel(parent, c)

dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded, nil) // 已过期
return c, func() { c.cancel(false, Canceled, nil) }
}

c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
// 设置定时器,到期自动取消
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded, nil)
})
}
return c, func() { c.cancel(true, Canceled, nil) }
}

valueCtx:携带值的context

1
2
3
4
5
6
7
8
9
10
11
12
type valueCtx struct {
Context // 嵌入父context
key, val interface{}
}

func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
// 向上递归查找
return value(c.Context, key)
}

Value的查找是沿着context链向上遍历的,时间复杂度为O(n)。

flowchart LR
    V3["valueCtx<br/>key=traceID"] -->|"not found"| V2["valueCtx<br/>key=userID"]
    V2 -->|"not found"| C1["cancelCtx"]
    C1 -->|"not found"| BG["Background<br/>return nil"]

    Q["Value(userID)"] -.->|"查找"| V3
    V2 -.->|"found!"| R["返回 userID 的值"]

    style R fill:#9f9

实战用法

WithCancel:手动取消

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package main

import (
"context"
"fmt"
"time"
)

func worker(ctx context.Context, id int) {
for {
select {
case <-ctx.Done():
fmt.Printf("Worker %d: stopped, reason: %v\n", id, ctx.Err())
return
default:
fmt.Printf("Worker %d: working...\n", id)
time.Sleep(500 * time.Millisecond)
}
}
}

func main() {
ctx, cancel := context.WithCancel(context.Background())

// 启动多个worker
for i := 1; i <= 3; i++ {
go worker(ctx, i)
}

// 运行2秒后取消
time.Sleep(2 * time.Second)
fmt.Println("Main: cancelling all workers...")
cancel()

// 等待worker处理完取消信号
time.Sleep(100 * time.Millisecond)
fmt.Println("Main: done")
}

WithTimeout:超时控制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
package main

import (
"context"
"fmt"
"time"
)

// 模拟数据库查询
func queryDatabase(ctx context.Context, query string) (string, error) {
// 模拟慢查询
resultCh := make(chan string, 1)
go func() {
time.Sleep(3 * time.Second) // 模拟耗时操作
resultCh <- "query result"
}()

select {
case result := <-resultCh:
return result, nil
case <-ctx.Done():
return "", fmt.Errorf("query cancelled: %w", ctx.Err())
}
}

func main() {
// 设置2秒超时
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() // 即使正常完成也要调用cancel释放资源

result, err := queryDatabase(ctx, "SELECT * FROM users")
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("Result: %s\n", result)
}

WithValue:传递请求范围的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package main

import (
"context"
"fmt"
"net/http"
)

// 使用自定义类型作为key,避免冲突
type contextKey string

const (
requestIDKey contextKey = "requestID"
userIDKey contextKey = "userID"
)

// 中间件:注入请求ID
func requestIDMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqID := r.Header.Get("X-Request-ID")
if reqID == "" {
reqID = generateID() // 生成唯一ID
}
ctx := context.WithValue(r.Context(), requestIDKey, reqID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// 从context中获取请求ID
func getRequestID(ctx context.Context) string {
if id, ok := ctx.Value(requestIDKey).(string); ok {
return id
}
return "unknown"
}

func handler(w http.ResponseWriter, r *http.Request) {
reqID := getRequestID(r.Context())
fmt.Fprintf(w, "Request ID: %s\n", reqID)
}

func generateID() string {
return "req-12345" // 实际应用中使用UUID
}

func main() {
mux := http.NewServeMux()
mux.Handle("/", requestIDMiddleware(http.HandlerFunc(handler)))
fmt.Println("Server starting on :8080")
http.ListenAndServe(":8080", mux)
}

WithCancelCause(Go 1.20+)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
package main

import (
"context"
"errors"
"fmt"
)

func main() {
ctx, cancel := context.WithCancelCause(context.Background())

go func() {
// 带原因的取消
cancel(fmt.Errorf("database connection lost"))
}()

<-ctx.Done()
fmt.Println("Err:", ctx.Err()) // context canceled
fmt.Println("Cause:", context.Cause(ctx)) // database connection lost
fmt.Println("Is canceled:", errors.Is(ctx.Err(), context.Canceled)) // true
}

AfterFunc(Go 1.21+)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
package main

import (
"context"
"fmt"
"time"
)

func main() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

// 当context取消时执行清理函数
stop := context.AfterFunc(ctx, func() {
fmt.Println("Context cancelled, performing cleanup...")
})

// 如果在context取消前就不需要cleanup了,可以取消注册
_ = stop // stop() 会取消注册

<-ctx.Done()
time.Sleep(100 * time.Millisecond) // 等待AfterFunc执行
}

Context传播模式

HTTP客户端传播

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package main

import (
"context"
"fmt"
"io"
"net/http"
"time"
)

func fetchWithContext(ctx context.Context, url string) (string, error) {
// 创建带context的HTTP请求
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("executing request: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading body: %w", err)
}

return string(body), nil
}

func main() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

result, err := fetchWithContext(ctx, "https://httpbin.org/get")
if err != nil {
fmt.Printf("Error: %v\n", err)
return
}
fmt.Printf("Response length: %d\n", len(result))
}

数据库操作传播

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
package main

import (
"context"
"database/sql"
"fmt"
"time"
)

type UserRepository struct {
db *sql.DB
}

func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
// context传递给数据库查询,超时会自动取消查询
row := r.db.QueryRowContext(ctx,
"SELECT id, name, email FROM users WHERE id = $1", id)

var user User
if err := row.Scan(&user.ID, &user.Name, &user.Email); err != nil {
return nil, fmt.Errorf("scanning user: %w", err)
}
return &user, nil
}

type User struct {
ID string
Name string
Email string
}

// 服务层传播context
type UserService struct {
repo *UserRepository
}

func (s *UserService) GetUser(ctx context.Context, id string) (*User, error) {
// 可以进一步设置更细粒度的超时
queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()

return s.repo.FindByID(queryCtx, id)
}

反模式与最佳实践

反模式1:将Context存储在结构体中

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 错误:不要把context存在结构体中
type BadService struct {
ctx context.Context // 不要这样做!
db *sql.DB
}

// 正确:context应该作为函数的第一个参数传递
type GoodService struct {
db *sql.DB
}

func (s *GoodService) DoWork(ctx context.Context) error {
// context作为参数传递
return s.db.PingContext(ctx)
}

反模式2:使用Context传递业务数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 错误:不要用context传递业务逻辑所需的参数
func processOrder(ctx context.Context) error {
orderID := ctx.Value("orderID").(string) // 不要这样做!
amount := ctx.Value("amount").(float64) // 不要这样做!
// ...
}

// 正确:业务数据作为函数参数显式传递
func processOrder(ctx context.Context, orderID string, amount float64) error {
// context只用于传递请求范围的元数据(traceID、requestID等)
traceID := getTraceID(ctx) // 这是合适的用法
_ = traceID
// ...
}

反模式3:忽略cancel函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 错误:泄漏context资源
func leakyFunction() {
ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
// cancel被丢弃了!timer资源不会被提前释放
doWork(ctx)
}

// 正确:始终调用cancel
func correctFunction() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() // 确保资源释放
doWork(ctx)
}

func doWork(ctx context.Context) {}

最佳实践汇总

graph TB
    subgraph "Context 最佳实践"
        A["1. context作为第一个参数<br/>命名为ctx"]
        B["2. 不要传nil context<br/>不确定时用TODO()"]
        C["3. WithValue只传请求范围数据<br/>traceID, requestID等"]
        D["4. 始终defer cancel()"]
        E["5. 不要存在struct中"]
        F["6. 不要传业务参数"]
        G["7. key用unexported自定义类型"]
    end

性能考量

Context创建的开销

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package main

import (
"context"
"testing"
"time"
)

func BenchmarkWithCancel(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
_, cancel := context.WithCancel(ctx)
cancel()
}
}

func BenchmarkWithTimeout(b *testing.B) {
ctx := context.Background()
for i := 0; i < b.N; i++ {
_, cancel := context.WithTimeout(ctx, time.Second)
cancel()
}
}

func BenchmarkWithValue(b *testing.B) {
ctx := context.Background()
key := contextKey("key")
for i := 0; i < b.N; i++ {
_ = context.WithValue(ctx, key, "value")
}
}

// 深层Value查找的开销
func BenchmarkDeepValueLookup(b *testing.B) {
ctx := context.Background()
// 构建深度为100的context链
for i := 0; i < 100; i++ {
ctx = context.WithValue(ctx, contextKey(fmt.Sprintf("key-%d", i)), i)
}
targetKey := contextKey("key-0") // 最深层的key

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ctx.Value(targetKey)
}
}

type contextKey string

深层的WithValue链会导致Value查找变慢(O(n))。如果需要传递多个值,可以将它们打包到一个结构体中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 将多个值打包为一个
type RequestMetadata struct {
TraceID string
RequestID string
UserAgent string
}

type metadataKey struct{}

func WithMetadata(ctx context.Context, md *RequestMetadata) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}

func GetMetadata(ctx context.Context) *RequestMetadata {
if md, ok := ctx.Value(metadataKey{}).(*RequestMetadata); ok {
return md
}
return nil
}

总结

Context是Go并发编程和服务端开发的核心工具,正确使用它可以:

  1. 统一取消传播:一个请求的取消能自动传播到所有相关的goroutine和IO操作
  2. 超时控制:防止请求无限等待,保护系统资源
  3. 元数据传递:在请求链路中传递traceID等观测数据

关键要点: - Context形成树形结构,父取消会传播到所有子context - cancelCtx通过关闭done channel和递归取消children实现级联取消 - timerCtx基于time.AfterFunc实现自动超时取消 - valueCtx通过链表向上查找,深度不宜过大 - 遵循最佳实践,避免将context存储在结构体中或用于传递业务参数

作者 · authorzt
发布 · date2023-08-09
篇幅 · length3.5k 字 · 9 min
许可 · licenseCC BY-SA 4.0
$ echo "comments" · 评论