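Introduction
context.Context is one of the most important interfaces in the Go standard library. It provides a standard way to carry deadlines, cancellation signals, and request-scoped values across API boundaries and between goroutines. Since it joined the standard library in Go 1.7, context has become a cornerstone of server-side Go programming.
Yet misuse of context is also one of the most common problems in Go development. Starting from the interface definition, this article walks through the source code behind context and uses practical examples to discuss best practices and the anti-patterns to avoid.
The Context Interface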
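    type Context interface {
        // Deadline returns the time at which the context will be cancelled.
        // ok is false when no deadline is set.
        Deadline() (deadline time.Time, ok bool)

        // Done returns a channel that is closed when the context is cancelled.
        Done() <-chan struct{}

        // Err returns nil until Done is closed, then the reason for cancellation.
        Err() error

        // Value returns the value associated with key, or nil if there is none.
        Value(key interface{}) interface{}
    }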
graph TB
subgraph "Context interface"
I["Context"]
I --> D["Deadline() - returns the deadline"]
I --> DN["Done() - cancellation notification channel"]
I --> E["Err() - reason for cancellation"]
I --> V["Value() - carried values"]
end
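The Context Tree
Functions such as WithCancel, WithTimeout, WithDeadline, and WithValue derive child contexts from a parent, forming a tree. When a parent context is cancelled, all of its descendants are cancelled as well.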
graph TB
BG["context.Background()"] --> C1["WithCancel"]
BG --> C2["WithTimeout(5s)"]
C1 --> C3["WithValue(userID)"]
C1 --> C4["WithTimeout(3s)"]
C2 --> C5["WithCancel"]
C2 --> C6["WithValue(traceID)"]
style BG fill:#9f9
style C1 fill:#ff9
style C2 fill:#ff9
style C3 fill:#9ff
style C4 fill:#ff9
style C5 fill:#ff9
style C6 fill:#9ff
When C1 is cancelled, C3 and C4 are cancelled too. When C2 times out, C5 and C6 are cancelled as well.
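The cascade can be observed directly. The following is a small sketch that rebuilds part of the tree with WithCancel only; the variable names follow the diagram, and the helper `cascade` is an assumption made for illustration.

```go
package main

import (
	"context"
	"fmt"
)

// cascade builds root -> c1 -> c4 and root -> c2, cancels c1, and reports
// which of the three derived contexts are cancelled afterwards.
func cascade() (c1Done, c4Done, c2Done bool) {
	root := context.Background()

	c1, cancelC1 := context.WithCancel(root)
	c2, cancelC2 := context.WithCancel(root)
	defer cancelC2()

	// c4 is a child of c1, like the C1 -> C4 edge in the diagram.
	c4, cancelC4 := context.WithCancel(c1)
	defer cancelC4()

	cancelC1() // cancel the parent only

	return c1.Err() != nil, c4.Err() != nil, c2.Err() != nil
}

func main() {
	c1Done, c4Done, c2Done := cascade()
	fmt.Println("c1 cancelled:", c1Done) // c1 cancelled: true
	fmt.Println("c4 cancelled:", c4Done) // c4 cancelled: true
	fmt.Println("c2 cancelled:", c2Done) // c2 cancelled: false
}
```

Cancellation flows downward only: the sibling subtree under c2 and the root itself are unaffected.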
Source Code Analysis
emptyCtx: the root context
    // Simplified from the Go 1.20 source. emptyCtx is an int rather than
    // struct{} so that background and todo have distinct addresses.
    type emptyCtx int

    func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { return }
    func (*emptyCtx) Done() <-chan struct{}                   { return nil }
    func (*emptyCtx) Err() error                              { return nil }
    func (*emptyCtx) Value(key interface{}) interface{}       { return nil }

    var (
        background = new(emptyCtx)
        todo       = new(emptyCtx)
    )

    func Background() Context { return background }
    func TODO() Context       { return todo }
cancelCtx: cancellable context
    type cancelCtx struct {
        Context

        mu       sync.Mutex            // protects the fields below
        done     atomic.Value          // holds a chan struct{}, created lazily
        children map[canceler]struct{} // set to nil by the first cancel call
        err      error                 // set to non-nil by the first cancel call
        cause    error                 // set by the first cancel call
    }

    func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
        c.mu.Lock()
        if c.err != nil {
            c.mu.Unlock()
            return // already cancelled
        }
        c.err = err
        c.cause = cause
        d, _ := c.done.Load().(chan struct{})
        if d == nil {
            // Done was never called: store a shared, already-closed channel.
            c.done.Store(closedchan)
        } else {
            close(d)
        }
        // Cascade to every registered child.
        for child := range c.children {
            child.cancel(false, err, cause)
        }
        c.children = nil
        c.mu.Unlock()

        if removeFromParent {
            removeChild(c.Context, c)
        }
    }
sequenceDiagram
participant P as Parent Context
participant C as cancelCtx
participant CH as children
participant D as done channel
Note over P,D: WithCancel creation
P->>C: create cancelCtx{parent}
C->>P: register as a child of the parent
Note over P,D: cancel() call
C->>C: set err
C->>D: close(done)
loop iterate over children
C->>CH: child.cancel()
end
C->>P: remove itself from the parent
timerCtx: context with a timeout
    type timerCtx struct {
        cancelCtx
        timer *time.Timer // guarded by cancelCtx.mu

        deadline time.Time
    }

    func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
        return WithDeadline(parent, time.Now().Add(timeout))
    }

    func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
        if cur, ok := parent.Deadline(); ok && cur.Before(d) {
            // The parent's deadline is already earlier than d.
            return WithCancel(parent)
        }
        c := &timerCtx{
            cancelCtx: newCancelCtx(parent),
            deadline:  d,
        }
        propagateCancel(parent, c)
        dur := time.Until(d)
        if dur <= 0 {
            c.cancel(true, DeadlineExceeded, nil) // the deadline has already passed
            return c, func() { c.cancel(false, Canceled, nil) }
        }
        c.mu.Lock()
        defer c.mu.Unlock()
        if c.err == nil {
            c.timer = time.AfterFunc(dur, func() {
                c.cancel(true, DeadlineExceeded, nil)
            })
        }
        return c, func() { c.cancel(true, Canceled, nil) }
    }
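The first branch of WithDeadline means a child can never outlive its parent's deadline. As a rough sketch of that behavior (the helper name `childDeadlineCapped` and the chosen durations are assumptions made for illustration):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// childDeadlineCapped asks for a one-minute timeout under a parent that
// only allows one second, and reports whether the child simply inherited
// the parent's deadline.
func childDeadlineCapped() bool {
	parent, cancelParent := context.WithTimeout(context.Background(), time.Second)
	defer cancelParent()

	child, cancelChild := context.WithTimeout(parent, time.Minute)
	defer cancelChild()

	pd, _ := parent.Deadline()
	cd, _ := child.Deadline()
	return cd.Equal(pd)
}

func main() {
	fmt.Println("child deadline equals parent deadline:", childDeadlineCapped())
	// child deadline equals parent deadline: true
}
```

In practice this means a per-call timeout such as 2 seconds acts as an upper bound only; if the incoming request context expires sooner, the shorter deadline applies.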
valueCtx: context that carries a value
    type valueCtx struct {
        Context
        key, val interface{}
    }

    func (c *valueCtx) Value(key interface{}) interface{} {
        if c.key == key {
            return c.val
        }
        // Not found here: continue with the parent chain.
        return value(c.Context, key)
    }
Value lookups walk up the context chain, so the time complexity is O(n), where n is the depth of the chain.
flowchart LR
V3["valueCtx<br/>key=traceID"] -->|"not found"| V2["valueCtx<br/>key=userID"]
V2 -->|"not found"| C1["cancelCtx"]
C1 -->|"not found"| BG["Background<br/>return nil"]
Q["Value(userID)"] -.->|"lookup"| V3
V2 -.->|"found!"| R["returns the userID value"]
style R fill:#9f9
Practical Usage
WithCancel: manual cancellation
    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func worker(ctx context.Context, id int) {
        for {
            select {
            case <-ctx.Done():
                fmt.Printf("Worker %d: stopped, reason: %v\n", id, ctx.Err())
                return
            default:
                fmt.Printf("Worker %d: working...\n", id)
                time.Sleep(500 * time.Millisecond)
            }
        }
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())

        for i := 1; i <= 3; i++ {
            go worker(ctx, i)
        }

        time.Sleep(2 * time.Second)
        fmt.Println("Main: cancelling all workers...")
        cancel()

        // Workers check ctx only once per 500ms iteration, so wait a little
        // longer than that for them to observe the cancellation.
        time.Sleep(600 * time.Millisecond)
        fmt.Println("Main: done")
    }
WithTimeout: timeout control
    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func queryDatabase(ctx context.Context, query string) (string, error) {
        // Buffered so the goroutine can finish even if nobody is receiving.
        resultCh := make(chan string, 1)

        go func() {
            time.Sleep(3 * time.Second) // simulate a slow query
            resultCh <- "query result"
        }()

        select {
        case result := <-resultCh:
            return result, nil
        case <-ctx.Done():
            return "", fmt.Errorf("query cancelled: %w", ctx.Err())
        }
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()

        result, err := queryDatabase(ctx, "SELECT * FROM users")
        if err != nil {
            fmt.Printf("Error: %v\n", err)
            return
        }
        fmt.Printf("Result: %s\n", result)
    }
WithValue: passing request-scoped values
    package main

    import (
        "context"
        "fmt"
        "net/http"
    )

    // contextKey is an unexported type, so keys cannot collide with other packages.
    type contextKey string

    const (
        requestIDKey contextKey = "requestID"
        userIDKey    contextKey = "userID"
    )

    func requestIDMiddleware(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            reqID := r.Header.Get("X-Request-ID")
            if reqID == "" {
                reqID = generateID()
            }
            ctx := context.WithValue(r.Context(), requestIDKey, reqID)
            next.ServeHTTP(w, r.WithContext(ctx))
        })
    }

    func getRequestID(ctx context.Context) string {
        if id, ok := ctx.Value(requestIDKey).(string); ok {
            return id
        }
        return "unknown"
    }

    func handler(w http.ResponseWriter, r *http.Request) {
        reqID := getRequestID(r.Context())
        fmt.Fprintf(w, "Request ID: %s\n", reqID)
    }

    func generateID() string {
        return "req-12345" // placeholder ID
    }

    func main() {
        mux := http.NewServeMux()
        mux.Handle("/", requestIDMiddleware(http.HandlerFunc(handler)))

        fmt.Println("Server starting on :8080")
        if err := http.ListenAndServe(":8080", mux); err != nil {
            fmt.Printf("Server error: %v\n", err)
        }
    }
WithCancelCause (Go 1.20+)
    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    func main() {
        ctx, cancel := context.WithCancelCause(context.Background())

        go func() {
            cancel(fmt.Errorf("database connection lost"))
        }()

        <-ctx.Done()

        // Err still reports context.Canceled; Cause carries the specific reason.
        fmt.Println("Err:", ctx.Err())
        fmt.Println("Cause:", context.Cause(ctx))
        fmt.Println("Is canceled:", errors.Is(ctx.Err(), context.Canceled))
    }
AfterFunc (Go 1.21+)
    package main

    import (
        "context"
        "fmt"
        "time"
    )

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
        defer cancel()

        // The callback runs in its own goroutine once ctx is done.
        stop := context.AfterFunc(ctx, func() {
            fmt.Println("Context cancelled, performing cleanup...")
        })
        _ = stop // stop() would detach the callback before it runs

        <-ctx.Done()
        time.Sleep(100 * time.Millisecond) // give the callback time to run
    }
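The example above discards stop. As a hedged sketch of what stop is for, the following detaches the callback before the context is cancelled and checks that it never runs. The helper name `stopBeforeCancel` and the 50ms wait are assumptions made for illustration, and the code requires Go 1.21 or later.

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// stopBeforeCancel registers a callback, detaches it with stop, and only
// then cancels the context. It reports what stop returned and whether the
// callback ran.
func stopBeforeCancel() (stopped, ran bool) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var called atomic.Bool
	stop := context.AfterFunc(ctx, func() { called.Store(true) })

	stopped = stop() // true because the callback had not been started yet
	cancel()

	// Wait briefly so a wrongly scheduled callback would have time to run.
	time.Sleep(50 * time.Millisecond)
	return stopped, called.Load()
}

func main() {
	stopped, ran := stopBeforeCancel()
	fmt.Println("stopped:", stopped) // stopped: true
	fmt.Println("ran:", ran)         // ran: false
}
```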
Context Propagation Patterns
Propagation through an HTTP client
    package main

    import (
        "context"
        "fmt"
        "io"
        "net/http"
        "time"
    )

    func fetchWithContext(ctx context.Context, url string) (string, error) {
        // The request is aborted as soon as ctx is cancelled or times out.
        req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
        if err != nil {
            return "", fmt.Errorf("creating request: %w", err)
        }

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            return "", fmt.Errorf("executing request: %w", err)
        }
        defer resp.Body.Close()

        body, err := io.ReadAll(resp.Body)
        if err != nil {
            return "", fmt.Errorf("reading body: %w", err)
        }
        return string(body), nil
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()

        result, err := fetchWithContext(ctx, "https://httpbin.org/get")
        if err != nil {
            fmt.Printf("Error: %v\n", err)
            return
        }
        fmt.Printf("Response length: %d\n", len(result))
    }
Propagation through database operations
    package main

    import (
        "context"
        "database/sql"
        "fmt"
        "time"
    )

    type User struct {
        ID    string
        Name  string
        Email string
    }

    type UserRepository struct {
        db *sql.DB
    }

    func (r *UserRepository) FindByID(ctx context.Context, id string) (*User, error) {
        row := r.db.QueryRowContext(ctx,
            "SELECT id, name, email FROM users WHERE id = $1", id)

        var user User
        if err := row.Scan(&user.ID, &user.Name, &user.Email); err != nil {
            return nil, fmt.Errorf("scanning user: %w", err)
        }
        return &user, nil
    }

    type UserService struct {
        repo *UserRepository
    }

    func (s *UserService) GetUser(ctx context.Context, id string) (*User, error) {
        // Derive a per-query timeout from the caller's context.
        queryCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
        defer cancel()

        return s.repo.FindByID(queryCtx, id)
    }
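Anti-patterns and Best Practices
Anti-pattern 1: storing a Context in a struct
    // Bad: the context is stored in a struct and can outlive the request.
    type BadService struct {
        ctx context.Context
        db  *sql.DB
    }

    // Good: the context is passed explicitly to each method call.
    type GoodService struct {
        db *sql.DB
    }

    func (s *GoodService) DoWork(ctx context.Context) error {
        return s.db.PingContext(ctx)
    }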
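Anti-pattern 2: using Context to pass business data
    // Bad: business parameters are hidden inside the context.
    func processOrderBad(ctx context.Context) error {
        orderID := ctx.Value("orderID").(string) // panics if the value is missing
        amount := ctx.Value("amount").(float64)
        _, _ = orderID, amount
        return nil
    }

    // Good: business parameters are explicit; only request-scoped metadata
    // such as the trace ID comes from the context.
    func processOrder(ctx context.Context, orderID string, amount float64) error {
        traceID := getTraceID(ctx)
        _ = traceID
        return nil
    }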
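Anti-pattern 3: ignoring the cancel function
    // Bad: the cancel function is discarded, so the timer and the context's
    // resources are held until the 5-second timeout fires.
    func leakyFunction() {
        ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)
        doWork(ctx)
    }

    // Good: defer cancel() releases the resources as soon as the function returns.
    func correctFunction() {
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        doWork(ctx)
    }

    func doWork(ctx context.Context) {}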
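Best practices summary
graph TB
subgraph "Context best practices"
A["1. Pass context as the first parameter<br/>and name it ctx"]
B["2. Never pass a nil context<br/>use TODO() when unsure"]
C["3. Use WithValue only for request-scoped data<br/>traceID, requestID, etc."]
D["4. Always defer cancel()"]
E["5. Do not store it in a struct"]
F["6. Do not use it to pass business parameters"]
G["7. Use an unexported custom type for keys"]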
end
Performance Considerations
The cost of creating a Context
    package main

    import (
        "context"
        "fmt"
        "testing"
        "time"
    )

    type contextKey string

    func BenchmarkWithCancel(b *testing.B) {
        ctx := context.Background()
        for i := 0; i < b.N; i++ {
            _, cancel := context.WithCancel(ctx)
            cancel()
        }
    }

    func BenchmarkWithTimeout(b *testing.B) {
        ctx := context.Background()
        for i := 0; i < b.N; i++ {
            _, cancel := context.WithTimeout(ctx, time.Second)
            cancel()
        }
    }

    func BenchmarkWithValue(b *testing.B) {
        ctx := context.Background()
        key := contextKey("key")
        for i := 0; i < b.N; i++ {
            _ = context.WithValue(ctx, key, "value")
        }
    }

    func BenchmarkDeepValueLookup(b *testing.B) {
        // Build a chain of 100 valueCtx nodes, then look up the outermost key,
        // which forces a walk over the whole chain.
        ctx := context.Background()
        for i := 0; i < 100; i++ {
            ctx = context.WithValue(ctx, contextKey(fmt.Sprintf("key-%d", i)), i)
        }
        targetKey := contextKey("key-0")

        b.ResetTimer()
        for i := 0; i < b.N; i++ {
            _ = ctx.Value(targetKey)
        }
    }
A deep WithValue chain slows down Value lookups (O(n)). If you need to pass several values, you can pack them into a single struct:
    type RequestMetadata struct {
        TraceID   string
        RequestID string
        UserAgent string
    }

    // metadataKey is an unexported empty struct type used as the context key.
    type metadataKey struct{}

    func WithMetadata(ctx context.Context, md *RequestMetadata) context.Context {
        return context.WithValue(ctx, metadataKey{}, md)
    }

    // GetMetadata returns nil when no metadata has been attached.
    func GetMetadata(ctx context.Context) *RequestMetadata {
        if md, ok := ctx.Value(metadataKey{}).(*RequestMetadata); ok {
            return md
        }
        return nil
    }
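Summary
Context is a core tool for concurrent programming and server-side development in Go. Used correctly, it provides:
- Unified cancellation propagation: cancelling a request automatically propagates to all related goroutines and I/O operations
- Timeout control: prevents requests from waiting indefinitely and protects system resources
- Metadata propagation: carries observability data such as the trace ID along the request path
Key takeaways:
- Contexts form a tree; cancelling a parent propagates to all of its child contexts
- cancelCtx implements cascading cancellation by closing the done channel and recursively cancelling its children
- timerCtx uses time.AfterFunc to cancel automatically on timeout
- valueCtx looks values up by walking a linked list toward the root, so the chain should not grow too deep
- Follow the best practices: avoid storing a context in a struct or using it to pass business parameters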