Go · #generics#go#type-parameters

Go泛型深入解析与实战应用

2023.10.25 Go 10 min 4.2k
// 目录 · contents

引言

Go 1.18引入了泛型(Generics),这是Go语言自诞生以来最大的语法变化。泛型允许我们编写类型参数化的函数和数据结构,在保持类型安全的同时减少代码重复。

在泛型出现之前,Go开发者不得不为不同类型编写几乎相同的代码,或使用interface{}牺牲类型安全。泛型的引入为Go打开了新的编程范式。本文将全面解析Go泛型的语法、约束系统、类型推断机制,并通过实际案例展示其在数据结构和算法中的应用。

泛型基础语法

类型参数(Type Parameters)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package main

import "fmt"

// 泛型函数:T是类型参数,any是约束
func Max[T interface{ ~int | ~float64 | ~string }](a, b T) T {
if a > b {
return a
}
return b
}

// 多个类型参数
func Map[T any, R any](slice []T, fn func(T) R) []R {
result := make([]R, len(slice))
for i, v := range slice {
result[i] = fn(v)
}
return result
}

func main() {
// 显式指定类型参数
fmt.Println(Max[int](3, 5)) // 5
fmt.Println(Max[string]("a", "b")) // b

// 类型推断(编译器自动推断类型参数)
fmt.Println(Max(3.14, 2.71)) // 3.14

// Map函数
nums := []int{1, 2, 3, 4, 5}
doubled := Map(nums, func(n int) int { return n * 2 })
fmt.Println(doubled) // [2 4 6 8 10]

strs := Map(nums, func(n int) string {
return fmt.Sprintf("item-%d", n)
})
fmt.Println(strs) // [item-1 item-2 item-3 item-4 item-5]
}

泛型类型(Generic Types)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package main

import "fmt"

// 泛型栈
type Stack[T any] struct {
items []T
}

func NewStack[T any]() *Stack[T] {
return &Stack[T]{items: make([]T, 0)}
}

func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}

func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
item := s.items[len(s.items)-1]
s.items = s.items[:len(s.items)-1]
return item, true
}

func (s *Stack[T]) Peek() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
return s.items[len(s.items)-1], true
}

func (s *Stack[T]) Len() int {
return len(s.items)
}

func main() {
// int栈
intStack := NewStack[int]()
intStack.Push(1)
intStack.Push(2)
intStack.Push(3)

for intStack.Len() > 0 {
val, _ := intStack.Pop()
fmt.Println(val) // 3, 2, 1
}

// string栈
strStack := NewStack[string]()
strStack.Push("hello")
strStack.Push("world")
top, _ := strStack.Peek()
fmt.Println(top) // world
}

约束(Constraints)

内置约束

graph TB
    subgraph "Go 内置约束层次"
        ANY["any<br/>(interface{})"]
        COMPARABLE["comparable<br/>(支持 == 和 !=)"]

        ANY --> COMPARABLE

        subgraph "constraints 包 (golang.org/x/exp)"
            ORDERED["Ordered<br/>(支持 < > <= >=)"]
            SIGNED["Signed<br/>(有符号整数)"]
            UNSIGNED["Unsigned<br/>(无符号整数)"]
            INTEGER["Integer<br/>(所有整数)"]
            FLOAT["Float<br/>(浮点数)"]
            COMPLEX["Complex<br/>(复数)"]
        end
    end

    style ANY fill:#9f9
    style COMPARABLE fill:#ff9
    style ORDERED fill:#9ff

自定义约束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package main

import (
"fmt"
"strings"
)

// 基本约束
type Number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~float32 | ~float64
}

// 有序约束
type Ordered interface {
Number | ~string
}

// 带方法的约束
type Stringer interface {
String() string
}

// 组合约束
type OrderedStringer interface {
Ordered
Stringer
}

// ~T 表示底层类型为T的所有类型
type MyInt int

func (m MyInt) String() string {
return fmt.Sprintf("MyInt(%d)", m)
}

// Sum 使用Number约束
func Sum[T Number](nums []T) T {
var total T
for _, n := range nums {
total += n
}
return total
}

// Contains 使用comparable约束
func Contains[T comparable](slice []T, target T) bool {
for _, v := range slice {
if v == target {
return true
}
}
return false
}

// Sort 使用Ordered约束
func Sort[T Ordered](slice []T) {
// 简单的插入排序
for i := 1; i < len(slice); i++ {
key := slice[i]
j := i - 1
for j >= 0 && slice[j] > key {
slice[j+1] = slice[j]
j--
}
slice[j+1] = key
}
}

func main() {
// Number约束
fmt.Println(Sum([]int{1, 2, 3, 4, 5})) // 15
fmt.Println(Sum([]float64{1.1, 2.2, 3.3})) // 6.6

// comparable约束
fmt.Println(Contains([]string{"a", "b", "c"}, "b")) // true
fmt.Println(Contains([]int{1, 2, 3}, 4)) // false

// Ordered约束
ints := []int{5, 3, 1, 4, 2}
Sort(ints)
fmt.Println(ints) // [1 2 3 4 5]

strs := []string{"banana", "apple", "cherry"}
Sort(strs)
fmt.Println(strs) // [apple banana cherry]

_ = strings.Builder{}
}

类型集合语义

graph LR
    subgraph "约束即类型集合"
        A["interface { int | string }"] --> |类型集合| B["{int, string}"]
        C["interface { ~int }"] --> |底层类型| D["{int, MyInt, YourInt, ...}"]
        E["interface { int; String() string }"] --> |交集| F["{实现了String()的int}"]
    end
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// 约束中的类型集合

// 联合约束:类型集合的并集
type IntOrString interface {
~int | ~string
}

// ~T 包含所有底层类型为T的类型
type Numeric interface {
~int | ~float64
}

// 接口约束:同时要求类型集合和方法
type StringLike interface {
~string | ~[]byte
Len() int // 注意:string没有Len()方法,这个约束实际上只匹配自定义类型
}

// 实际中更常见的组合方式
type Formatter interface {
comparable
Format() string
}

类型推断

Go编译器能在很多情况下自动推断类型参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package main

import "fmt"

func Filter[T any](slice []T, predicate func(T) bool) []T {
var result []T
for _, v := range slice {
if predicate(v) {
result = append(result, v)
}
}
return result
}

func Reduce[T any, R any](slice []T, initial R, fn func(R, T) R) R {
result := initial
for _, v := range slice {
result = fn(result, v)
}
return result
}

func GroupBy[T any, K comparable](slice []T, keyFn func(T) K) map[K][]T {
result := make(map[K][]T)
for _, item := range slice {
key := keyFn(item)
result[key] = append(result[key], item)
}
return result
}

type Person struct {
Name string
Age int
City string
}

func main() {
people := []Person{
{"Alice", 30, "Beijing"},
{"Bob", 25, "Shanghai"},
{"Charlie", 35, "Beijing"},
{"Diana", 28, "Shanghai"},
{"Eve", 32, "Guangzhou"},
}

// 类型推断:编译器从people和lambda推断出T=Person
adults := Filter(people, func(p Person) bool {
return p.Age >= 30
})
fmt.Println("Adults:", adults)

// 类型推断:T=Person, R=int
totalAge := Reduce(people, 0, func(sum int, p Person) int {
return sum + p.Age
})
fmt.Println("Total age:", totalAge)

// 类型推断:T=Person, K=string
byCity := GroupBy(people, func(p Person) string {
return p.City
})
for city, residents := range byCity {
fmt.Printf("%s: %v\n", city, residents)
}
}

泛型数据结构

泛型链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
package main

import "fmt"

type Node[T any] struct {
Value T
Next *Node[T]
}

type LinkedList[T any] struct {
Head *Node[T]
Tail *Node[T]
Len int
}

func NewLinkedList[T any]() *LinkedList[T] {
return &LinkedList[T]{}
}

func (l *LinkedList[T]) PushBack(value T) {
node := &Node[T]{Value: value}
if l.Tail == nil {
l.Head = node
l.Tail = node
} else {
l.Tail.Next = node
l.Tail = node
}
l.Len++
}

func (l *LinkedList[T]) PushFront(value T) {
node := &Node[T]{Value: value, Next: l.Head}
l.Head = node
if l.Tail == nil {
l.Tail = node
}
l.Len++
}

func (l *LinkedList[T]) PopFront() (T, bool) {
if l.Head == nil {
var zero T
return zero, false
}
value := l.Head.Value
l.Head = l.Head.Next
if l.Head == nil {
l.Tail = nil
}
l.Len--
return value, true
}

// ForEach 遍历链表
func (l *LinkedList[T]) ForEach(fn func(T)) {
current := l.Head
for current != nil {
fn(current.Value)
current = current.Next
}
}

// ToSlice 转换为切片
func (l *LinkedList[T]) ToSlice() []T {
result := make([]T, 0, l.Len)
l.ForEach(func(v T) {
result = append(result, v)
})
return result
}

func main() {
list := NewLinkedList[int]()
list.PushBack(1)
list.PushBack(2)
list.PushBack(3)
list.PushFront(0)

fmt.Println(list.ToSlice()) // [0 1 2 3]

val, ok := list.PopFront()
fmt.Printf("Popped: %d (ok=%v)\n", val, ok) // Popped: 0 (ok=true)
fmt.Printf("Length: %d\n", list.Len) // Length: 3
}

泛型有序Map(基于红黑树简化版)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package main

import "fmt"

// 约束定义
type Ordered interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~float32 | ~float64 | ~string
}

// 简化的有序Map(基于排序切片)
type OrderedMap[K Ordered, V any] struct {
keys []K
values []V
}

func NewOrderedMap[K Ordered, V any]() *OrderedMap[K, V] {
return &OrderedMap[K, V]{}
}

func (m *OrderedMap[K, V]) Set(key K, value V) {
idx := m.search(key)
if idx < len(m.keys) && m.keys[idx] == key {
m.values[idx] = value
return
}
// 插入新元素
m.keys = append(m.keys, key)
m.values = append(m.values, value)
// 移动元素保持有序
for i := len(m.keys) - 1; i > idx; i-- {
m.keys[i] = m.keys[i-1]
m.values[i] = m.values[i-1]
}
m.keys[idx] = key
m.values[idx] = value
}

func (m *OrderedMap[K, V]) Get(key K) (V, bool) {
idx := m.search(key)
if idx < len(m.keys) && m.keys[idx] == key {
return m.values[idx], true
}
var zero V
return zero, false
}

// 二分查找
func (m *OrderedMap[K, V]) search(key K) int {
lo, hi := 0, len(m.keys)
for lo < hi {
mid := lo + (hi-lo)/2
if m.keys[mid] < key {
lo = mid + 1
} else {
hi = mid
}
}
return lo
}

func (m *OrderedMap[K, V]) Keys() []K {
result := make([]K, len(m.keys))
copy(result, m.keys)
return result
}

func (m *OrderedMap[K, V]) Len() int {
return len(m.keys)
}

func main() {
om := NewOrderedMap[string, int]()
om.Set("cherry", 3)
om.Set("apple", 1)
om.Set("banana", 2)

fmt.Println("Keys (ordered):", om.Keys()) // [apple banana cherry]

if val, ok := om.Get("banana"); ok {
fmt.Println("banana =", val) // banana = 2
}
}

泛型算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package main

import "fmt"

type Ordered interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~float32 | ~float64 | ~string
}

// Unique 去重(保持顺序)
func Unique[T comparable](slice []T) []T {
seen := make(map[T]struct{})
result := make([]T, 0)
for _, v := range slice {
if _, ok := seen[v]; !ok {
seen[v] = struct{}{}
result = append(result, v)
}
}
return result
}

// Chunk 将切片分成固定大小的块
func Chunk[T any](slice []T, size int) [][]T {
if size <= 0 {
return nil
}
var chunks [][]T
for i := 0; i < len(slice); i += size {
end := i + size
if end > len(slice) {
end = len(slice)
}
chunks = append(chunks, slice[i:end])
}
return chunks
}

// Zip 将两个切片合并为元组切片
type Pair[A, B any] struct {
First A
Second B
}

func Zip[A, B any](a []A, b []B) []Pair[A, B] {
minLen := len(a)
if len(b) < minLen {
minLen = len(b)
}
result := make([]Pair[A, B], minLen)
for i := 0; i < minLen; i++ {
result[i] = Pair[A, B]{First: a[i], Second: b[i]}
}
return result
}

// Min 返回切片中的最小值
func Min[T Ordered](slice []T) (T, bool) {
if len(slice) == 0 {
var zero T
return zero, false
}
min := slice[0]
for _, v := range slice[1:] {
if v < min {
min = v
}
}
return min, true
}

func main() {
// Unique
nums := []int{1, 2, 3, 2, 1, 4, 3, 5}
fmt.Println("Unique:", Unique(nums)) // [1 2 3 4 5]

// Chunk
data := []int{1, 2, 3, 4, 5, 6, 7}
fmt.Println("Chunks:", Chunk(data, 3)) // [[1 2 3] [4 5 6] [7]]

// Zip
names := []string{"Alice", "Bob", "Charlie"}
ages := []int{30, 25, 35}
pairs := Zip(names, ages)
for _, p := range pairs {
fmt.Printf("%s: %d\n", p.First, p.Second)
}

// Min
min, _ := Min([]int{5, 3, 8, 1, 9})
fmt.Println("Min:", min) // 1
}

泛型的局限性

graph TB
    subgraph "Go泛型的局限性"
        A["不支持方法级别的类型参数<br/>方法只能使用结构体的类型参数"]
        B["不支持特化<br/>无法为特定类型提供优化实现"]
        C["不支持运算符重载<br/>自定义类型不能直接用 < > +"]
        D["不支持类型参数的可变参数"]
        E["约束不支持字段访问<br/>只能约束方法"]
    end
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 局限1:方法不能有自己的类型参数
type Container[T any] struct {
items []T
}

// 错误:方法不能引入新的类型参数
// func (c *Container[T]) Transform[R any](fn func(T) R) *Container[R] { ... }

// 正确:使用包级函数代替
func Transform[T, R any](c *Container[T], fn func(T) R) *Container[R] {
result := &Container[R]{items: make([]R, len(c.items))}
for i, item := range c.items {
result.items[i] = fn(item)
}
return result
}

// 局限2:不支持特化
// 无法写出类似这样的代码:
// func Sum[T Number](s []T) T { ... } // 通用版本
// func Sum[int](s []int) int { ... } // int特化版本(不支持)

// 局限3:约束不能访问字段
// 错误:
// type HasName interface {
// Name string // 不能约束字段
// }
// 只能约束方法:
type HasName interface {
GetName() string
}

性能影响

Go泛型使用GCShape stenciling策略:相同GC形状(GC shape)的类型共享同一份机器代码。

graph LR
    subgraph "GCShape Stenciling"
        F["func Max[T Ordered](a, b T) T"]
        F --> S1["Max_shape_int<br/>(int, int8, int16...)"]
        F --> S2["Max_shape_ptr<br/>(*T 所有指针类型)"]
        F --> S3["Max_shape_string<br/>(string)"]
    end
    Note["相同GC shape的类型<br/>共享一份代码+字典"] -.-> F
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package main

import "testing"

// 基准测试:泛型 vs 具体类型
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}

func maxGeneric[T interface{ ~int | ~float64 }](a, b T) T {
if a > b {
return a
}
return b
}

func BenchmarkMaxConcrete(b *testing.B) {
for i := 0; i < b.N; i++ {
maxInt(3, 5)
}
}

func BenchmarkMaxGeneric(b *testing.B) {
for i := 0; i < b.N; i++ {
maxGeneric(3, 5)
}
}

// 在大多数情况下,性能差异可以忽略不计
// 编译器能够内联泛型函数

标准库中的泛型

Go 1.21引入的slicesmaps包提供了开箱即用的泛型工具:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package main

import (
"fmt"
"maps"
"slices"
)

func main() {
// slices包
nums := []int{5, 3, 8, 1, 9, 2}

// 排序
slices.Sort(nums)
fmt.Println("Sorted:", nums) // [1 2 3 5 8 9]

// 二分查找
idx, found := slices.BinarySearch(nums, 5)
fmt.Printf("Found 5 at index %d: %v\n", idx, found)

// 包含
fmt.Println("Contains 3:", slices.Contains(nums, 3))

// 最大最小值
fmt.Println("Max:", slices.Max(nums))
fmt.Println("Min:", slices.Min(nums))

// 去重(需要先排序)
duped := []int{1, 1, 2, 2, 3, 3}
slices.Sort(duped)
unique := slices.Compact(duped)
fmt.Println("Compact:", unique) // [1 2 3]

// maps包
m1 := map[string]int{"a": 1, "b": 2, "c": 3}
m2 := map[string]int{"a": 1, "b": 2, "c": 3}

fmt.Println("Equal:", maps.Equal(m1, m2)) // true

// 克隆
m3 := maps.Clone(m1)
fmt.Println("Clone:", m3)

// 获取所有keys
keys := slices.Sorted(maps.Keys(m1))
fmt.Println("Keys:", keys) // [a b c]
}

总结

Go泛型为Go语言带来了重要的表达能力:

  1. 类型参数允许函数和类型在多种类型上工作,保持类型安全
  2. 约束系统基于接口的类型集合语义,灵活而强大
  3. 类型推断使得大部分场景下无需显式指定类型参数
  4. 泛型数据结构(Stack、LinkedList、OrderedMap等)消除了重复代码
  5. 泛型算法(Map、Filter、Reduce等)提升了代码复用性

使用建议: - 只在确实需要时使用泛型,不要为了用而用 - 优先使用标准库的slicesmaps包 - 当发现自己为多种类型写相同逻辑时,考虑泛型 - 约束尽量窄:用Ordered代替any,当需要比较时 - 注意泛型的局限性,必要时退回到接口+类型断言

作者 · authorzt
发布 · date2023-10-25
篇幅 · length4.2k 字 · 10 min
许可 · licenseCC BY-SA 4.0
$ echo "comments" · 评论