蛮荆

如何实现可靠的 UDP ?

2022-10-15

概述

UDP 是无连接的,尽最大可能交付 (网络 IP 层),不提供流量控制,拥塞控制,面向报文(对于应用程序的数据不合并也不拆分,直接添加 UDP 首部),支持一对一、一对多、多对一和多对多 (也就是多播和广播) 通信。

TCP 和 UDP 区别总结

TCP 实现可靠传输层的核心有三点:

  1. 确认与重传 (已经可以满足 “可靠性”,但是可能存在性能问题)
  2. 滑动窗口 (也就是流量控制,为了提高吞吐量,充分利用链路带宽,避免发送方发的太慢)
  3. 拥塞控制 (防止网络链路过载造成丢包,避免发送方发的太快)

如果希望使用 UDP 来实现 TCP 的可靠传输,显然 最直接的方法就是在应用层实现 确认与重传。本文沿着这个思路,看看利用 UDP 来实现确认与重传机制时的设计思路和实现伪代码。

除了确认和重传外,TCP 校验和也是 TCP 实现可靠传输的手段之一,但这不是本文的重点,所以这里一笔带过。

为了节省篇幅,本文将可靠传输的模型实现简化为:

发送方实现重传机制,接收方实现确认机制。


数据结构

为了简化代码和便于理解,下文中将每个 UDP 数据包封装一个对象,字段及含义如下所示。


// 数据包标志位类型
type FlagType uint8

const (
	FlagTypeInvalid FlagType = iota
	FlagTypeData             // 数据包
	FlagTypeAck              // 确认包
)

// 定义数据包的结构体
type Packet struct {
	Seq  int      // 序列号
	Ack  int      // 确认号
	Data string   // 数据内容
	Flag FlagType // 标志位
}

确认机制

确认机制由接收方实现,在本文中也就是服务端程序。

1. 单个数据包确认

所谓单个确认,也就是常见的 Reply 形式,发送方 (客户端) 向接收方发送一个 UDP 数据包,对于每个接收到的 UDP 数据包,接收方 (服务端) 都向发送方发送一个确认 ACK 数据包。

单个数据包确认示例图

实现思路非常简单:

  1. 服务端程序使用 UDP 监听指定端口
  2. 客户端向服务端发送 UDP 数据包
  3. 服务端收到客户端的 UDP 数据包之后,向客户端发送 ACK 数据包
  4. 客户端收到服务端的 ACK 数据包之后,更新 Seq 值

最终的程序代码及其对应的注释如下。


// V1 版本

package main

import (
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"
)

// 数据包标志位类型
type FlagType uint8

const (
	FlagTypeInvalid FlagType = iota
	FlagTypeData             // 数据包
	FlagTypeAck              // 确认包
)

// 定义数据包的结构体
type Packet struct {
	Seq  int      // 序列号
	Ack  int      // 确认号
	Data string   // 数据内容
	Flag FlagType // 标志位
}

var (
	// 服务端地址
	serverAddr = net.UDPAddr{
		Port: 8080,
		IP:   net.ParseIP("127.0.0.1"),
	}
)

func main() {
	go startServer()

	// 等待服务端程序启动
	time.Sleep(200 * time.Millisecond)

	startClient()
}

// 服务端程序
func startServer() {
	conn, err := net.ListenUDP("udp", &serverAddr)
	if err != nil {
		fmt.Println("Error starting server:", err)
		return
	}
	defer conn.Close()

	buffer := make([]byte, 1024)

	for {
		n, clientAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			continue
		}

		// 解析接收到的数据包
		recvPacket := decode(buffer[:n])

		fmt.Printf("client -> server %s\n", serialization(&recvPacket))

		// 构造 Ack 包并发送
		ackPacket := Packet{
			// 因为这个示例中
			// 服务端不主动发送数据
			// 所以 Seq 固定为 1
			Seq:  1,
			Ack:  recvPacket.Seq + len(recvPacket.Data),
			Data: "",
			Flag: FlagTypeAck,
		}

		ackData := encode(&ackPacket)
		conn.WriteToUDP(ackData, clientAddr)
	}
}

// 客户端程序
func startClient() {
	conn, err := net.DialUDP("udp", nil, &serverAddr)
	if err != nil {
		fmt.Println("Error connecting:", err)
		return
	}
	defer conn.Close()

	// 发送一个数据包
	packet := Packet{
		Seq:  1, // 客户端 Seq 值从 1 开始
		Ack:  1,
		Data: "Hello Server",
		Flag: FlagTypeData,
	}

	// 发送 5 个 UDP 数据包
	for i := 0; i < 5; i++ {
		data := encode(&packet)
		conn.Write(data)

		// 接收 Ack 包
		buffer := make([]byte, 1024)
		n, _, err := conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			return
		}

		recvAckPacket := decode(buffer[:n])
		fmt.Printf("server -> client %s\n", serialization(&recvAckPacket))

		// 更新下次发送数据包的 Seq 值
		packet.Seq = recvAckPacket.Ack
	}
}

// Packet 数据包编码
// 使用字符串拼接作为简单实现
func encode(p *Packet) []byte {
	return []byte(fmt.Sprintf("%d|%d|%q|%d", p.Seq, p.Ack, p.Data, p.Flag))
}

// Packet 数据包解码
func decode(data []byte) Packet {
	var p Packet
	_, _ = fmt.Sscanf(string(data), "%d|%d|%q|%d", &p.Seq, &p.Ack, &p.Data, &p.Flag)
	return p
}

// 格式化数据包显示
// 模拟 WireShark 的输出格式
func serialization(p *Packet) string {
	var sb strings.Builder

	if p.Flag == FlagTypeData {
		// 无需任何标志位渲染
		// 输出占位符美化终端显示
		sb.WriteString("     ")
	} else if p.Flag == FlagTypeAck {
		sb.WriteString("[ACK]")
	} else {
		sb.WriteString("[Unknown]")
	}

	sb.WriteString(" Seq=")
	sb.WriteString(strconv.Itoa(p.Seq))

	if p.Flag == FlagTypeAck {
		sb.WriteString(" Ack=")
		sb.WriteString(strconv.Itoa(p.Ack))
	}

	sb.WriteString(" Len=")
	sb.WriteString(strconv.Itoa(len(p.Data)))

	if p.Flag == FlagTypeData {
		sb.WriteString(" Data=")
		sb.WriteString(p.Data)
	}

	return sb.String()
}

运行程序的输出如下:

通过输出结果可以看到,单个数据包的确认机制实现,已经可以正常工作了。

2. 延迟确认

服务端在不发送数据的情况下,每收到一个 UDP 数据包,就发送 Ack 报文,导致了低效的数据传输和浪费网络带宽,也就是所谓的 “糊涂窗口综合症”。

既然服务端没有什么数据要发送给客户端,那么就可以延迟一段时间再发送 Ack 报文, 如果在延迟期间,又接收到了新的数据,就可以将多个 Ack 报文合并到一个数据包里面发送了。

延迟确认示例图

当然,要重构的主要服务端程序和客户端程序的代码,修改后的代码如下所示。


// V2 版本

// 其他重复代码省略
// ...


// 服务端程序
func startServer() {
	conn, err := net.ListenUDP("udp", &serverAddr)
	if err != nil {
		fmt.Println("Error starting server:", err)
		return
	}
	defer conn.Close()

	buffer := make([]byte, 1024)

	// 延迟 200 毫秒发送 ACK
	const ackDelay = 200 * time.Millisecond

	var (
		// 延迟 Ack
		lastAck int
		// 最后发送 Ack 报文的时间
		lastAckTime = time.Now()
		// 客户端的 UDP 地址
		clientAddr *net.UDPAddr
	)

	// 因为 conn.ReadFromUDP 方法是阻塞接收操作
	// 所以这里启动一个新的 goroutine
	// 来完成延迟 Ack 操作
	go func() {
		for {
			// 超过延迟时间,发送 Ack 确认包
			if time.Since(lastAckTime) >= ackDelay {
				// 超过延迟时间,发送 Ack 确认包
				// 构造 Ack 包并发送
				ackPacket := Packet{
					// 因为这个示例中
					// 服务端不主动发送数据
					// 所以 Seq 固定为 1
					Seq:  1,
					Ack:  lastAck,
					Data: "",
					Flag: FlagTypeAck,
				}

				ackData := encode(&ackPacket)
				conn.WriteToUDP(ackData, clientAddr)

				// 更新最后发送 Ack 的时间
				lastAckTime = time.Now()
			}

			// 短暂休眠,避免占用过多 CPU 资源
			time.Sleep(100 * time.Millisecond)
		}
	}()

	for {
		_, clientAddr, err = conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			continue
		}

		// 解析接收到的数据包
		recvPacket := decode(buffer[:])

		fmt.Printf("client -> server %s\n", serialization(&recvPacket))

		// 更新最后接收到的确认号
		lastAck = recvPacket.Seq + len(recvPacket.Data)
	}
}

// 客户端程序
func startClient() {
	conn, err := net.DialUDP("udp", nil, &serverAddr)
	if err != nil {
		fmt.Println("Error connecting:", err)
		return
	}
	defer conn.Close()

	// 构建一个 UDP 数据包
	packet := Packet{
		Seq:  1, // 客户端 Seq 值从 1 开始
		Ack:  1,
		Data: "Hello Server",
		Flag: FlagTypeData,
	}

	var wg sync.WaitGroup
	wg.Add(1)

	// 这里启动一个新的 goroutine
	// 来完成接收 Ack 操作
	go func() {
		defer wg.Done()

		// 接收 Ack 包
		buffer := make([]byte, 1024)
		_, _, err := conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			return
		}

		recvAckPacket := decode(buffer[:])
		fmt.Printf("server -> client %s\n", serialization(&recvAckPacket))
	}()

	// 连续发送 5 个 UDP 数据包
	for i := 0; i < 5; i++ {
		data := encode(&packet)
		conn.Write(data)

		// 更新下次发送数据包的 Seq 值
		packet.Seq += len(packet.Data)
	}

	// 等待 Ack 报文接收完成
	wg.Wait()
}


// 其他重复代码省略
// ...

运行程序的输出如下:

通过输出结果可以看到,客户端连续发送了 5 个 UDP 数据包,但是因为服务端启动了延迟确认,最终发送给客户端的 Ack 报文只有 1 个。

3. 选择性确认

在选择性重传中,接收方通过 SAck 向发送方应答已经收到的非连续数据包,发送方可以作为依据,来重传接收方没有收到 (可能已经丢失) 的数据包。

如图所示,SAck = 1361 ~ 2721 表示这个区间的数据包已经收到了。

代码实现也很简单,就是接收方记录已经接收到的数据包 Seq,并定时将每个区间 Seq 的最大值作为 Ack 报文响应给发送方。

因为我们的伪代码只考虑接收方实现选择性确认,所以只需要在刚才的代码基础上,对服务端和客户端代码稍加调整即可。

  • 在数据包 Packet 对象中新增一个 SAck 字段
  • 修改 Packet 对象的编码、解码、渲染方法
  • 接收方 (服务端) 实现选择性确认 Ack
  • 客户端实现模拟丢包和超时自动退出

选择性确认示例图

最后修改后的代码如下所示。


// V3 版本

package main

import (
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)

// 数据包标志位类型
type FlagType uint8

const (
	FlagTypeInvalid FlagType = iota
	FlagTypeData             // 数据包
	FlagTypeAck              // 确认包
)

// 定义数据包的结构体
type Packet struct {
	Seq  int      // 序列号
	Ack  int      // 确认号
	SAck string   // SAck 区间
	Data string   // 数据内容
	Flag FlagType // 标志位
}

var (
	// 服务端地址
	serverAddr = net.UDPAddr{
		Port: 8080,
		IP:   net.ParseIP("127.0.0.1"),
	}
)

func main() {
	go startServer()

	// 等待服务端程序启动
	time.Sleep(200 * time.Millisecond)

	startClient()
}

// 服务端程序
func startServer() {
	conn, err := net.ListenUDP("udp", &serverAddr)
	if err != nil {
		fmt.Println("Error starting server:", err)
		return
	}
	defer conn.Close()

	buffer := make([]byte, 32)

	// 延迟 200 毫秒发送 ACK
	const ackDelay = 200 * time.Millisecond

	var (
		// 延迟 Ack
		lastAck int

		// 记录接收到的区间 Seq
		// [0]: 区间起始 Seq
		// [1]: 区间结束 Seq, Seq + Data.Len()
		seqList = [][2]int{}

		// 最后发送 Ack 报文的时间
		lastAckTime = time.Now()
		// 客户端的 UDP 地址
		clientAddr *net.UDPAddr
	)

	// 因为 conn.ReadFromUDP 方法是阻塞接收操作
	// 所以这里启动一个新的 goroutine
	// 来完成延迟 Ack 操作
	go func() {
		for {
			// 超过延迟时间,发送 Ack 确认包
			if time.Since(lastAckTime) >= ackDelay && len(seqList) > 0 {
				// 超过延迟时间,发送 Ack 确认包
				// 构造 Ack 包并发送

				lastAck = seqList[0][1]
				lastAckChanged := false

				// 因为丢包,可能存在多个区间 Ack 确认包
				// 所以需要分开单独发送
				// 根据 Seq 合并区间
				mergedSeqList := [][2]int{
					seqList[0],
				}

				for i := 1; i < len(seqList); i++ {
					// 数据包 Seq 是连续的,直接合并两个区间
					if seqList[i][0] == mergedSeqList[len(mergedSeqList)-1][1] {
						mergedSeqList[len(mergedSeqList)-1][1] = seqList[i][1]

						// 更新最后接收到的确认号
						if !lastAckChanged {
							lastAck = mergedSeqList[len(mergedSeqList)-1][1]
						}
					} else {
						lastAckChanged = true

						// 数据包 Seq 不是连续的,有中间数据包还未收到
						mergedSeqList = append(mergedSeqList, seqList[i])
					}
				}

				for _, seq := range mergedSeqList {
					ackPacket := Packet{
						// 因为这个示例中
						// 服务端不主动发送数据
						// 所以 Seq 固定为 1
						Seq:  1,
						Ack:  lastAck,
						SAck: fmt.Sprintf("%d-%d", seq[0], seq[1]),
						Data: "",
						Flag: FlagTypeAck,
					}

					ackData := encode(&ackPacket)
					conn.WriteToUDP(ackData, clientAddr)
				}

				// 更新最后发送 Ack 的时间
				lastAckTime = time.Now()

				// 重置区间 Seq
				seqList = seqList[:0]
			}

			// 短暂休眠,避免占用过多 CPU 资源
			time.Sleep(100 * time.Millisecond)
		}
	}()

	for {
		_, clientAddr, err = conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			continue
		}

		// 解析接收到的数据包
		recvPacket := decode(buffer[:])

		fmt.Printf("client -> server %s\n", serialization(&recvPacket))

		// 记录接收到的区间 Seq
		seqList = append(seqList, [2]int{
			recvPacket.Seq,
			recvPacket.Seq + len(recvPacket.Data),
		})
	}
}

// 客户端程序
func startClient() {
	conn, err := net.DialUDP("udp", nil, &serverAddr)
	if err != nil {
		fmt.Println("Error connecting:", err)
		return
	}
	defer conn.Close()

	var wg sync.WaitGroup
	wg.Add(1)

	// 这里启动一个新的 goroutine
	// 来完成接收 Ack 操作
	go func() {
		defer wg.Done()

		t := time.NewTimer(1 * time.Second)
		defer t.Stop()

		for {
			select {
			case <-t.C:
				return
			default:
				// 接收 Ack 包
				buffer := make([]byte, 32)

				conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
				_, _, err := conn.ReadFromUDP(buffer)
				if err != nil {
					continue
				}

				recvAckPacket := decode(buffer[:])
				fmt.Printf("server -> client %s\n", serialization(&recvAckPacket))
			}
		}
	}()

	// 构建一个 UDP 数据包
	packet := Packet{
		Seq:  1, // 客户端 Seq 值从 1 开始
		Ack:  1,
		Data: "Hello Server",
		Flag: FlagTypeData,
	}

	// 连续发送 5 个 UDP 数据包
	for i := 0; i < 5; i++ {
		// 第 4 个数据包模拟丢包
		if i != 3 {
			data := encode(&packet)
			conn.Write(data)
		}

		// 更新下次发送数据包的 Seq 值
		packet.Seq += len(packet.Data)
	}

	// 等待 Ack 报文接收完成
	wg.Wait()
}

// Packet 数据包编码
// 使用字符串拼接作为简单实现
func encode(p *Packet) []byte {
	return []byte(fmt.Sprintf("%d|%d|%q|%q|%d", p.Seq, p.Ack, p.SAck, p.Data, p.Flag))
}

// Packet 数据包解码
func decode(data []byte) Packet {
	var p Packet
	_, _ = fmt.Sscanf(string(data), "%d|%d|%q|%q|%d", &p.Seq, &p.Ack, &p.SAck, &p.Data, &p.Flag)
	return p
}

// 格式化数据包显示
// 模拟 WireShark 的输出格式
func serialization(p *Packet) string {
	var sb strings.Builder

	if p.Flag == FlagTypeData {
		// 无需任何标志位渲染
		// 输出占位符美化终端显示
		sb.WriteString("     ")
	} else if p.Flag == FlagTypeAck {
		sb.WriteString("[ACK]")
	} else {
		sb.WriteString("[Unknown]")
	}

	sb.WriteString(" Seq=")
	sb.WriteString(strconv.Itoa(p.Seq))

	if p.Flag == FlagTypeAck {
		sb.WriteString(" Ack=")
		sb.WriteString(strconv.Itoa(p.Ack))

		if len(p.SAck) > 0 {
			sb.WriteString(" SAck=")
			sb.WriteString(p.SAck)
		}
	}

	sb.WriteString(" Len=")
	sb.WriteString(strconv.Itoa(len(p.Data)))

	if p.Flag == FlagTypeData {
		sb.WriteString(" Data=")
		sb.WriteString(p.Data)
	}

	return sb.String()
}

运行程序的输出如下:

通过输出结果可以看到,客户端连续发送了 5 个 UDP 数据包,其中第 4 个包模拟丢包 (服务端接收不到),但是因为服务端启动了选择性确认,所以最终发送给客户端的 Ack 报文有 2 个:

  • Ack=37: 表示 Seq 在 37 号之前数据包已经全部接收完成
  • SAck=49-61: 表示 Seq 在 49 号到 61 号之间的数据包已经全部接收完成

客户端根据这两个信息,就可以判断出丢包的具体数据包,也就是 Seq 在 37 号到 49 号之间的数据包,具体来说,也就是下面这个数据包:


client -> server       Seq=37 Len=12 Data=Hello Server

小结

使用 UDP 实现可靠性传输中的 确认机制,接收方 (服务端) 已经完成了,接下来就是发送方 (客户端) 要实现的 重传机制。有了前文的基础后,客户端部分代码实现起来应该也很快,继续 Coding :-)


重传机制

重传机制由发送方实现,在本文中也就是客户端端程序。

1. 超时重传

为了简化实现,本文不计算数据包往返的 RTT, RTO (超时重传时间) 直接采用 1 个暴力的硬编码值: 300 毫秒。

此外,因为前文中接收方 (服务端) 已经实现了选择性确认,所以这里将 超时重传 + 选择性重传 一起实现。

最后修改后的代码如下所示。


// V4 版本

package main

import (
	"fmt"
	"net"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"
)

// 数据包标志位类型
type FlagType uint8

const (
	FlagTypeInvalid FlagType = iota
	FlagTypeData             // 数据包
	FlagTypeAck              // 确认包
)

// 定义数据包的结构体
type Packet struct {
	Seq        int      // 序列号
	Ack        int      // 确认号
	SAck       string   // SAck 区间
	Data       string   // 数据内容
	Flag       FlagType // 标志位
	Retransmit bool     // 重传标志位
}

var (
	// 服务端地址
	serverAddr = net.UDPAddr{
		Port: 8080,
		IP:   net.ParseIP("127.0.0.1"),
	}
)

func main() {
	go startServer()

	// 等待服务端程序启动
	time.Sleep(200 * time.Millisecond)

	startClient()
}

// 服务端程序
func startServer() {
	conn, err := net.ListenUDP("udp", &serverAddr)
	if err != nil {
		fmt.Println("Error starting server:", err)
		return
	}
	defer conn.Close()

	buffer := make([]byte, 32)

	// 延迟 200 毫秒发送 ACK
	const ackDelay = 200 * time.Millisecond

	var (
		// 延迟 Ack
		lastAck int

		// 记录接收到的区间 Seq
		// [0]: 区间起始 Seq
		// [1]: 区间结束 Seq, Seq + Data.Len()
		seqList = [][2]int{}

		// 记录历史接收到的所有区间 Seq
		seqRecord = [][2]int{}

		// 最后发送 Ack 报文的时间
		lastAckTime = time.Now()
		// 客户端的 UDP 地址
		clientAddr *net.UDPAddr
	)

	// 因为 conn.ReadFromUDP 方法是阻塞接收操作
	// 所以这里启动一个新的 goroutine
	// 来完成延迟 Ack 操作
	go func() {
		for {
			// 超过延迟时间,发送 Ack 确认包
			if time.Since(lastAckTime) >= ackDelay && len(seqList) > 0 {
				// 超过延迟时间,发送 Ack 确认包
				// 构造 Ack 包并发送

				lastAck = seqList[0][1]
				lastAckChanged := false

				// 因为丢包,可能存在多个区间 Ack 确认包
				// 所以需要分开单独发送
				// 根据 Seq 合并区间
				mergedSeqList := [][2]int{
					seqList[0],
				}

				for i := 1; i < len(seqList); i++ {
					// 数据包 Seq 是连续的,直接合并两个区间
					if seqList[i][0] == mergedSeqList[len(mergedSeqList)-1][1] {
						mergedSeqList[len(mergedSeqList)-1][1] = seqList[i][1]

						// 更新最后接收到的确认号
						if !lastAckChanged {
							lastAck = mergedSeqList[len(mergedSeqList)-1][1]
						}
					} else {
						lastAckChanged = true

						// 数据包 Seq 不是连续的,有中间数据包还未收到
						mergedSeqList = append(mergedSeqList, seqList[i])
					}
				}

				for _, seq := range mergedSeqList {
					ackPacket := Packet{
						// 因为这个示例中
						// 服务端不主动发送数据
						// 所以 Seq 固定为 1
						Seq:  1,
						Ack:  lastAck,
						SAck: fmt.Sprintf("%d-%d", seq[0], seq[1]),
						Data: "",
						Flag: FlagTypeAck,
					}

					ackData := encode(&ackPacket)
					conn.WriteToUDP(ackData, clientAddr)
				}

				// 更新最后发送 Ack 的时间
				lastAckTime = time.Now()

				// 重置区间 Seq
				seqList = seqList[:0]
			}

			// 短暂休眠,避免占用过多 CPU 资源
			time.Sleep(100 * time.Millisecond)
		}
	}()

	for {
		_, clientAddr, err = conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			continue
		}

		// 解析接收到的数据包
		recvPacket := decode(buffer[:])

		fmt.Printf("client -> server %s\n", serialization(&recvPacket))

		// 记录历史区间 Seq
		seqRecord = append(seqRecord, [2]int{
			recvPacket.Seq,
			recvPacket.Seq + len(recvPacket.Data),
		})

		// 这里假设重传的数据包 100% 接收成功
		// 服务端直接返回确认 Ack 报文
		// 简化对重传数据包的再次 Ack 的实现机制
		if recvPacket.Retransmit {
			// 排序合并后的区间
			sort.Slice(seqRecord, func(i, j int) bool {
				return seqRecord[i][0] < seqRecord[j][0] && seqRecord[i][1] < seqRecord[j][1]
			})
			// 合并重复区间
			// 合并重复区间
			uniqueIndex := 0
			for i := 1; i < len(seqRecord); i++ {
				if seqRecord[i][0] == seqRecord[uniqueIndex][1] {
					seqRecord[uniqueIndex][1] = seqRecord[i][1]
				} else {
					uniqueIndex++
				}
			}
			seqRecord = seqRecord[:uniqueIndex+1]

			// 更新已经接收到连续区间最大 Ack
			lastAck = seqRecord[0][1]

			recvPacket.SAck = fmt.Sprintf("%d-%d", recvPacket.Seq, recvPacket.Seq+len(recvPacket.Data))
			recvPacket.Ack = lastAck

			recvPacket.Seq = 1
			recvPacket.Flag = FlagTypeAck
			conn.WriteToUDP(encode(&recvPacket), clientAddr)
			continue
		}

		// 记录接收到的区间 Seq
		seqList = append(seqList, [2]int{
			recvPacket.Seq,
			recvPacket.Seq + len(recvPacket.Data),
		})
	}
}

// 客户端程序
func startClient() {
	conn, err := net.DialUDP("udp", nil, &serverAddr)
	if err != nil {
		fmt.Println("Error connecting:", err)
		return
	}
	defer conn.Close()

	// 记录客户端已经发送过的数据包 Seq 列表
	sentPackets := []*Packet{}
	// 记录客户端已经接收到的数据包 Seq 列表
	receivedPackets := []*Packet{}

	var wg sync.WaitGroup
	wg.Add(1)

	// 这里启动一个新的 goroutine
	// 1. 完成超时重传
	// 2. 完成接收 Ack 操作
	go func() {
		defer wg.Done()

		// 超时退出
		timeout := time.NewTimer(1 * time.Second)
		defer timeout.Stop()

		// 超时重传定时器
		// 硬编码为 300 毫秒
		ticket := time.NewTicker(300 * time.Millisecond)
		defer ticket.Stop()

		for {
			select {
			case <-timeout.C:
				return
			case <-ticket.C:
				// 发送的数据包已经被接收方全部确认
				// 无需重传
				if len(sentPackets) == len(receivedPackets) {
					continue
				}

				// 通过区间差集算法
				// 同时考虑 选择性确认 的情况
				lostPackets := []*Packet{}
				receivedAckList := [][2]int{}
				for _, val := range receivedPackets {
					ackBlock := strings.Split(val.SAck, "-")
					start, _ := strconv.ParseInt(ackBlock[0], 10, 64)
					end, _ := strconv.ParseInt(ackBlock[1], 10, 64)
					receivedAckList = append(receivedAckList, [2]int{
						int(start),
						int(end),
					})
				}

				// 排序合并后的区间
				sort.Slice(receivedAckList, func(i, j int) bool {
					return receivedAckList[i][0] < receivedAckList[j][0] && receivedAckList[i][1] < receivedAckList[j][1]
				})
				// 合并重复区间
				uniqueIndex := 0
				for i := 1; i < len(receivedAckList); i++ {
					if receivedAckList[i][0] == receivedAckList[uniqueIndex][1] {
						receivedAckList[uniqueIndex][1] = receivedAckList[i][1]
					} else {
						uniqueIndex++
					}
				}
				receivedAckList = receivedAckList[:uniqueIndex+1]

				// 计算丢失的数据包
				curRecvIndex := 0
				for i, val := range sentPackets {
					if curRecvIndex >= len(receivedPackets) {
						lostPackets = append(lostPackets, val)
						continue
					}
					if val.Seq > receivedAckList[curRecvIndex][1] {
						curRecvIndex++
						lostPackets = append(lostPackets, sentPackets[i-1])
					}
				}

				for _, val := range lostPackets {
					// 构建 1 个 UDP 数据包
					packet := Packet{
						Seq:        val.Seq,
						Ack:        1,
						Data:       "Hello Server",
						Flag:       FlagTypeData,
						Retransmit: true,
					}

					data := encode(&packet)
					conn.Write(data)
				}
			default:
				// 接收 Ack 包
				buffer := make([]byte, 32)

				conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
				_, _, err := conn.ReadFromUDP(buffer)
				if err != nil {
					continue
				}

				recvAckPacket := decode(buffer[:])
				fmt.Printf("server -> client %s\n", serialization(&recvAckPacket))

				// 更新接收到的数据包 Seq
				receivedPackets = append(receivedPackets, &recvAckPacket)
			}
		}
	}()

	//  客户端 Seq 值从 1 开始
	curSeq := 1

	// 连续发送 5 个 UDP 数据包
	for i := 0; i < 5; i++ {
		// 构建 1 个 UDP 数据包
		packet := Packet{
			Seq:  curSeq,
			Ack:  1,
			Data: "Hello Server",
			Flag: FlagTypeData,
		}

		// 更新发送过的数据包 Seq
		sentPackets = append(sentPackets, &packet)

		// 第 4 个数据包模拟丢包
		if i != 3 {
			data := encode(&packet)
			conn.Write(data)
		}

		// 更新下次发送数据包的 Seq 值
		curSeq += len(packet.Data)
	}

	// 等待 Ack 报文接收完成
	wg.Wait()
}

// Packet 数据包编码
// 使用字符串拼接作为简单实现
func encode(p *Packet) []byte {
	return []byte(fmt.Sprintf("%d|%d|%q|%q|%d|%t", p.Seq, p.Ack, p.SAck, p.Data, p.Flag, p.Retransmit))
}

// Packet 数据包解码
func decode(data []byte) Packet {
	var p Packet
	_, _ = fmt.Sscanf(string(data), "%d|%d|%q|%q|%d|%t", &p.Seq, &p.Ack, &p.SAck, &p.Data, &p.Flag, &p.Retransmit)
	return p
}

// 格式化数据包显示
// 模拟 WireShark 的输出格式
func serialization(p *Packet) string {
	var sb strings.Builder

	if p.Retransmit {
		sb.WriteString("[TCP Retransmit] ")
	}

	if p.Flag == FlagTypeData {
		// 无需任何标志位渲染
		// 输出占位符美化终端显示
		if !p.Retransmit {
			sb.WriteString("     ")
		}
	} else if p.Flag == FlagTypeAck {
		sb.WriteString("[ACK]")
	} else {
		sb.WriteString("[Unknown]")
	}

	sb.WriteString(" Seq=")
	sb.WriteString(strconv.Itoa(p.Seq))

	if p.Flag == FlagTypeAck {
		sb.WriteString(" Ack=")
		sb.WriteString(strconv.Itoa(p.Ack))

		if len(p.SAck) > 0 {
			sb.WriteString(" SAck=")
			sb.WriteString(p.SAck)
		}
	}

	sb.WriteString(" Len=")
	sb.WriteString(strconv.Itoa(len(p.Data)))

	if p.Flag == FlagTypeData {
		sb.WriteString(" Data=")
		sb.WriteString(p.Data)
	}

	return sb.String()
}

运行程序的输出如下:

通过输出结果可以看到,客户端连续发送了 5 个 UDP 数据包,其中第 4 个包模拟丢包 (服务端接收不到),但是因为服务端启动了选择性确认,所以最终发送给客户端的 Ack 报文有 2 个:

  • Ack=37: 表示 Seq 在 37 号之前数据包已经全部接收完成
  • SAck=49-61: 表示 Seq 在 49 号到 61 号之间的数据包已经全部接收完成

客户端根据这两个信息,就可以判断出丢包的具体数据包,也就是 Seq 在 37 号到 49 号之间的数据包,具体来说,也就是下面这个数据包:


client -> server       Seq=37 Len=12 Data=Hello Server

客户端在超时计时器触发后,通过对比 已经收到的数据包 Ack已经发送的数据包 Seq 集合,计算出还未接受到的数据包,也就是丢包数据,然后重新发送,通过输出的结果,可以看到对应数据包中的 [TCP Retransmit] 标识信息。

2. 快速重传

快速重传机制依赖于重复确认(Duplicate Acknowledgments, Dup ACK)来检测数据包丢失,当接收方接收到一个乱序 (不连续) 的数据包时,会重新发送对最后一个按序 (连续) 到达的数据包的 Ack, 发送方收到一定数量 (3 个) 的重复 Ack 之后,认为数据包 (可能已经) 丢失,并立即重传该数据包。

实现方面,只需要通过在 超时重传 的代码基础上,对服务端程序略加修改,通过程序打乱接收到数据包的顺序,来模拟乱序到达,然后对于乱序的数据包,发送对应的 Dup ACK 响应报文即可。

最后修改后的代码如下所示。


// V5 版本

// 其他重复代码省略
// ...

const (
	FlagTypeInvalid FlagType = iota
	FlagTypeData             // 数据包
	FlagTypeAck              // 确认包
	FlagTypeDupAck           // 快速重传包
)


// 服务端程序
func startServer() {
	conn, err := net.ListenUDP("udp", &serverAddr)
	if err != nil {
		fmt.Println("Error starting server:", err)
		return
	}
	defer conn.Close()

	buffer := make([]byte, 32)

	// 延迟 200 毫秒发送 ACK
	const ackDelay = 200 * time.Millisecond

	var (
		// 延迟 Ack
		lastAck int

		// 记录接收到的区间 Seq
		// [0]: 区间起始 Seq
		// [1]: 区间结束 Seq, Seq + Data.Len()
		seqList = [][2]int{}

		// 记录历史接收到的所有区间 Seq
		seqRecord = [][2]int{}

		// 最后发送 Ack 报文的时间
		lastAckTime = time.Now()
		// 客户端的 UDP 地址
		clientAddr *net.UDPAddr
	)

	// 因为 conn.ReadFromUDP 方法是阻塞接收操作
	// 所以这里启动一个新的 goroutine
	// 来完成延迟 Ack 操作
	go func() {
		for {
			// 超过延迟时间,发送 Ack 确认包
			if time.Since(lastAckTime) >= ackDelay && len(seqList) > 0 {
				// 超过延迟时间,发送 Ack 确认包
				// 构造 Ack 包并发送
				lastAck = seqList[0][1]
				lastAckChanged := false

				// 程序模拟数据包乱序
				// 模拟除了第 1 个数据包之外
				// 其他的所有数据包都发生了乱序
				for i, j := 1, len(seqList)-1; i < j; i, j = i+1, j-1 {
					seqList[i], seqList[j] = seqList[j], seqList[i]
				}

				// 根据乱序数据包发送快速重传报文
				for _, val := range seqList {
					if val[0] > lastAck {
						ackPacket := Packet{
							// 因为这个示例中
							// 服务端不主动发送数据
							// 所以 Seq 固定为 1
							Seq:  1,
							Ack:  lastAck,
							SAck: "",
							Data: "",
							Flag: FlagTypeDupAck,
						}

						ackData := encode(&ackPacket)
						conn.WriteToUDP(ackData, clientAddr)
					} else {
						lastAck = val[1]
					}
				}

				// 排序合并后的区间
				sort.Slice(seqList, func(i, j int) bool {
					return seqList[i][0] < seqList[j][0] && seqList[i][1] < seqList[j][1]
				})

				// 因为丢包,可能存在多个区间 Ack 确认包
				// 所以需要分开单独发送
				// 根据 Seq 合并区间
				mergedSeqList := [][2]int{
					seqList[0],
				}

				for i := 1; i < len(seqList); i++ {
					// 数据包 Seq 是连续的,直接合并两个区间
					if seqList[i][0] == mergedSeqList[len(mergedSeqList)-1][1] {
						mergedSeqList[len(mergedSeqList)-1][1] = seqList[i][1]

						// 更新最后接收到的确认号
						if !lastAckChanged {
							lastAck = mergedSeqList[len(mergedSeqList)-1][1]
						}
					} else {
						lastAckChanged = true

						// 数据包 Seq 不是连续的,有中间数据包还未收到
						mergedSeqList = append(mergedSeqList, seqList[i])
					}
				}

				for _, seq := range mergedSeqList {
					ackPacket := Packet{
						// 因为这个示例中
						// 服务端不主动发送数据
						// 所以 Seq 固定为 1
						Seq:  1,
						Ack:  lastAck,
						SAck: fmt.Sprintf("%d-%d", seq[0], seq[1]),
						Data: "",
						Flag: FlagTypeAck,
					}

					ackData := encode(&ackPacket)
					conn.WriteToUDP(ackData, clientAddr)
				}

				// 更新最后发送 Ack 的时间
				lastAckTime = time.Now()

				// 重置区间 Seq
				seqList = seqList[:0]
			}

			// 短暂休眠,避免占用过多 CPU 资源
			time.Sleep(100 * time.Millisecond)
		}
	}()

	for {
		_, clientAddr, err = conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading:", err)
			continue
		}

		// 解析接收到的数据包
		recvPacket := decode(buffer[:])

		fmt.Printf("client -> server %s\n", serialization(&recvPacket))

		// 记录历史区间 Seq
		seqRecord = append(seqRecord, [2]int{
			recvPacket.Seq,
			recvPacket.Seq + len(recvPacket.Data),
		})

		// 这里假设重传的数据包 100% 接收成功
		// 服务端直接返回确认 Ack 报文
		// 简化对重传数据包的再次 Ack 的实现机制
		if recvPacket.Retransmit {
			// 排序合并后的区间
			sort.Slice(seqRecord, func(i, j int) bool {
				return seqRecord[i][0] < seqRecord[j][0] && seqRecord[i][1] < seqRecord[j][1]
			})
			// 合并重复区间
			// 合并重复区间
			uniqueIndex := 0
			for i := 1; i < len(seqRecord); i++ {
				if seqRecord[i][0] == seqRecord[uniqueIndex][1] {
					seqRecord[uniqueIndex][1] = seqRecord[i][1]
				} else {
					uniqueIndex++
				}
			}
			seqRecord = seqRecord[:uniqueIndex+1]

			// 更新已经接收到连续区间最大 Ack
			lastAck = seqRecord[0][1]

			recvPacket.SAck = fmt.Sprintf("%d-%d", recvPacket.Seq, recvPacket.Seq+len(recvPacket.Data))
			recvPacket.Ack = lastAck

			recvPacket.Seq = 1
			recvPacket.Flag = FlagTypeAck
			conn.WriteToUDP(encode(&recvPacket), clientAddr)
			continue
		}

		// 记录接收到的区间 Seq
		seqList = append(seqList, [2]int{
			recvPacket.Seq,
			recvPacket.Seq + len(recvPacket.Data),
		})
	}
}

// 客户端程序
func startClient() {
	conn, err := net.DialUDP("udp", nil, &serverAddr)
	if err != nil {
		fmt.Println("Error connecting:", err)
		return
	}
	defer conn.Close()

	// 记录客户端已经发送过的数据包 Seq 列表
	sentPackets := []*Packet{}
	// 记录客户端已经接收到的数据包 Seq 列表
	receivedPackets := []*Packet{}

	var wg sync.WaitGroup
	wg.Add(1)

	// 这里启动一个新的 goroutine
	// 1. 完成超时重传
	// 2. 完成接收 Ack 操作
	go func() {
		defer wg.Done()

		// 超时退出
		timeout := time.NewTimer(1 * time.Second)
		defer timeout.Stop()

		// 超时重传定时器
		// 硬编码为 300 毫秒
		ticket := time.NewTicker(300 * time.Millisecond)
		defer ticket.Stop()

		for {
			select {
			case <-timeout.C:
				return
			case <-ticket.C:
				// 发送的数据包已经被接收方全部确认
				// 无需重传
				if len(sentPackets) == len(receivedPackets) {
					continue
				}

				// 通过区间差集算法
				// 同时考虑 选择性确认 的情况
				lostPackets := []*Packet{}
				receivedAckList := [][2]int{}
				for _, val := range receivedPackets {
					ackBlock := strings.Split(val.SAck, "-")
					if len(ackBlock) < 2 {
						continue
					}
					start, _ := strconv.ParseInt(ackBlock[0], 10, 64)
					end, _ := strconv.ParseInt(ackBlock[1], 10, 64)
					receivedAckList = append(receivedAckList, [2]int{
						int(start),
						int(end),
					})
				}

				// 排序合并后的区间
				sort.Slice(receivedAckList, func(i, j int) bool {
					return receivedAckList[i][0] < receivedAckList[j][0] && receivedAckList[i][1] < receivedAckList[j][1]
				})
				// 合并重复区间
				uniqueIndex := 0
				for i := 1; i < len(receivedAckList); i++ {
					if receivedAckList[i][0] == receivedAckList[uniqueIndex][1] {
						receivedAckList[uniqueIndex][1] = receivedAckList[i][1]
					} else {
						uniqueIndex++
					}
				}
				receivedAckList = receivedAckList[:uniqueIndex+1]

				// 计算丢失的数据包
				curRecvIndex := 0
				for i, val := range sentPackets {
					if curRecvIndex >= len(receivedPackets) {
						lostPackets = append(lostPackets, val)
						continue
					}
					if val.Seq > receivedAckList[curRecvIndex][1] {
						curRecvIndex++
						lostPackets = append(lostPackets, sentPackets[i-1])
					}
				}

				for _, val := range lostPackets {
					// 构建 1 个 UDP 数据包
					packet := Packet{
						Seq:        val.Seq,
						Ack:        1,
						Data:       "Hello Server",
						Flag:       FlagTypeData,
						Retransmit: true,
					}

					data := encode(&packet)
					conn.Write(data)
				}
			default:
				// 接收 Ack 包
				buffer := make([]byte, 32)

				conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
				_, _, err := conn.ReadFromUDP(buffer)
				if err != nil {
					continue
				}

				recvAckPacket := decode(buffer[:])
				fmt.Printf("server -> client %s\n", serialization(&recvAckPacket))

				// 更新接收到的数据包 Seq
				receivedPackets = append(receivedPackets, &recvAckPacket)
			}
		}
	}()

	//  客户端 Seq 值从 1 开始
	curSeq := 1

	// 连续发送 5 个 UDP 数据包
	for i := 0; i <= 5; i++ {
		// 构建 1 个 UDP 数据包
		packet := Packet{
			Seq:  curSeq,
			Ack:  1,
			Data: "Hello Server",
			Flag: FlagTypeData,
		}

		// 更新发送过的数据包 Seq
		sentPackets = append(sentPackets, &packet)

		// 第 4 个数据包模拟丢包
		if i != 3 {
			data := encode(&packet)
			conn.Write(data)
		}

		// 更新下次发送数据包的 Seq 值
		curSeq += len(packet.Data)
	}

	// 等待 Ack 报文接收完成
	wg.Wait()
}


// 格式化数据包显示
// 模拟 WireShark 的输出格式
func serialization(p *Packet) string {
	var sb strings.Builder

	if p.Retransmit {
		sb.WriteString("[TCP Retransmit] ")
	}

	if p.Flag == FlagTypeData {
		// 无需任何标志位渲染
		// 输出占位符美化终端显示
		if !p.Retransmit {
			sb.WriteString("     ")
		}
	} else if p.Flag == FlagTypeAck {
		sb.WriteString("[ACK]")
	} else if p.Flag == FlagTypeDupAck {
		sb.WriteString("[TCP Dup ACK]")
	} else {
		sb.WriteString("[Unknown]")
	}

	sb.WriteString(" Seq=")
	sb.WriteString(strconv.Itoa(p.Seq))

	if p.Flag == FlagTypeAck || p.Flag == FlagTypeDupAck {
		sb.WriteString(" Ack=")
		sb.WriteString(strconv.Itoa(p.Ack))

		if len(p.SAck) > 0 {
			sb.WriteString(" SAck=")
			sb.WriteString(p.SAck)
		}
	}

	sb.WriteString(" Len=")
	sb.WriteString(strconv.Itoa(len(p.Data)))

	if p.Flag == FlagTypeData {
		sb.WriteString(" Data=")
		sb.WriteString(p.Data)
	}

	return sb.String()
}


// 其他重复代码省略
// ...

运行程序的输出如下:

通过输出结果可以看到,客户端连续发送了 6 个 UDP 数据包,其中第 4 个包模拟丢包 (服务端接收不到),但是因为服务端启动了选择性确认,所以最终发送给客户端的 Ack 报文有 2 个:

  • Ack=37: 表示 Seq 在 37 号之前数据包已经全部接收完成
  • SAck=49-73: 表示 Seq 在 49 号到 72 号之间的数据包已经全部接收完成

客户端根据这两个信息,就可以判断出丢包的具体数据包,也就是 Seq 在 37 号到 49 号之间的数据包,具体来说,也就是下面这个数据包:


client -> server       Seq=37 Len=12 Data=Hello Server

客户端在超时计时器触发后,通过对比 已经收到的数据包 Ack已经发送的数据包 Seq 集合,计算出还未接受到的数据包,也就是丢包数据,然后重新发送,通过输出的结果,可以看到对应数据包中的 [TCP Retransmit] 标识信息。

此外,通过在服务端模拟接收到的数据包乱序,服务端向客户端发送了快速重传 Dup ACK 报文,当然,上述代码实现的是一个纯演示版本。

3. 选择性重传

在前文中 超时重传 代码实现时已经顺带实现了,这里不再赘述。


小结

本文通过伪代码实现,演示了使用 UDP 来实现 TCP 中的确认与重传机制,文中整体的所有代码实现非常粗糙简陋以及高度耦合 (可以直接运行,但只是为了演示效果),而且没有考虑任何并发安全、错误处理、性能优化等工程问题,但是本文主要的目的在于说明设计思路,伪代码可以辅助理解实现细节,能到达这个目标就足够了。

大多数有过网络编程经验的开发者,或多或少会产生过一个执念: 通过 UDP 来实现和 TCP 一样的可靠传输保证 (RUDP),但这样也就失去了创造 UDP 本身的意思,退一步说,即使真的实现了,充其量也就是和 TCP 性能持平 (毕竟 TCP 处于内核态没有上下文切换成本,RUDP 处于用户态有上下文切换成本),没有任何技术价值。此外还要考虑网络链路中的 UDP 流量服务质量 (包括运营商限制、防火墙丢包等)。不过好在,今天我们有了新的选择: QUIC, a multiplexed transport over UDP

转载申请

本作品采用 知识共享署名 4.0 国际许可协议 进行许可,转载时请注明原文链接,图片在使用时请保留全部内容,商业转载请联系作者获得授权。