gin框架的中间件问题
问题
今天有一个同学来问我问题, 大概是这样的:
package main
import (
"fmt"
"github.com/gin-gonic/gin"
)
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := "http://localhost:5173"
c.Writer.Header().Set("Access-Control-Allow-origin", origin)
c.Writer.Header().Set("Access-control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-control-Allow-Headers", "Origin, Content-Type, Authorization, Accept, User-Agent, Cache-control, pragma")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Max-Age", "600")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func main() {
r := gin.Default()
v1 := r.Group("/api/v1")
v1.Use(CORSMiddleware())
v1.POST("/hello", func(c *gin.Context) {
fmt.Println("hello")
c.JSON(200, gin.H{
"message": "Hello World",
})
})
r.Run(":8080")
}
看上去没有问题, 但是前端发送http://localhost:8080/api/v1/hello
请求时报跨域错误, 后端控制台打印出OPTIONS /api/v1/hello 404
. 改成了r.use(CORSMiddleware())
就正常了. 看上去应该经过这个v1组中间件的请求但是没有经过, 感觉有点奇怪.
查看源码
于是我打算看一下gin是怎么工作的. 点进r.Run看一下, 有这个东西
err = http.ListenAndServe(address, engine.Handler())
address
是监听的地址, engine.Handler()
就是各种请求的处理器了. 点进去
func (engine *Engine) Handler() http.Handler {
if !engine.UseH2C {
return engine
}
h2s := &http2.Server{}
return h2c.NewHandler(engine, h2s)
}
http2
又是另一个库了, 这里不需要关心. 这里能使用h2c.NewHandler
说明Engine
实现了http.Handler
接口, 这个接口处理http请求的函数是ServeHTTP(ResponseWriter, *Request)
, 就跟不用gin框架写http服务器一样. 接下来只需要找到Engine
是怎么实现ServeHTTP
就可以了.
点进gin.Default
, 这个函数会返回一个Engine
类型变量, 然后搜索ServeHTTP
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := engine.pool.Get().(*Context) // 从池子里拿一个出来
c.writermem.reset(w)
c.Request = req
c.reset()
engine.handleHTTPRequest(c)
engine.pool.Put(c)
}
handleHTTPRequest
主要就看handleHTTPRequest
了, 其他的基本都是对context的初始化和结束处理.
func (engine *Engine) handleHTTPRequest(c *Context) {
httpMethod := c.Request.Method // 解析出Method字段
rPath := c.Request.URL.Path // 解析出路径, 如/api/v1/hello
unescape := false
// engine.UseRawPath默认为false
if engine.UseRawPath && len(c.Request.URL.RawPath) > 0 {
rPath = c.Request.URL.RawPath
unescape = engine.UnescapePathValues
}
// 同样默认为false
if engine.RemoveExtraSlash {
rPath = cleanPath(rPath)
}
// 拿出engine的MethodTree数组
t := engine.trees
for i, tl := 0, len(t); i < tl; i++ {
// 匹配方法
if t[i].method != httpMethod {
continue
}
// 取出对应方法的MethodTree
root := t[i].root
// 匹配路径, 取出对应的node
value := root.getValue(rPath, c.params, c.skippedNodes, unescape)
if value.params != nil {
c.Params = *value.params
}
if value.handlers != nil {
c.handlers = value.handlers
c.fullPath = value.fullPath
c.Next()
c.writermem.WriteHeaderNow()
return
}
if httpMethod != http.MethodConnect && rPath != "/" {
// tsr表示跟实际路由相差一个/
// 比如路径是/users, 但是只在/users/设置了handlers, 如果允许尾随斜杠重定向的话就会进行重定向
// 默认允许
if value.tsr && engine.RedirectTrailingSlash {
redirectTrailingSlash(c)
return
}
if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) {
return
}
}
break
}
// 处理方法不允许的情况
// 默认为false
if engine.HandleMethodNotAllowed {
// According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response
// containing a list of the target resource's currently supported methods.
allowed := make([]string, 0, len(t)-1)
for _, tree := range engine.trees {
if tree.method == httpMethod {
continue
}
if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil {
allowed = append(allowed, tree.method)
}
}
if len(allowed) > 0 {
c.handlers = engine.allNoMethod
c.writermem.Header().Set("Allow", strings.Join(allowed, ", "))
serveError(c, http.StatusMethodNotAllowed, default405Body)
return
}
}
// 上面匹配都没成功
c.handlers = engine.allNoRoute
serveError(c, http.StatusNotFound, default404Body)
}
这里可以看出来, 当一个请求到来时, gin会先匹配该请求的方法, 再从对应的方法树里面匹配路径. 当方法和路径都没匹配时, 就会进入serveError
中.
之前的OPTIONS http://localhost:8080/api/v1/hello
在这里就没能匹配成功方法, 进入到serveError
中. 继续看serveError
的实现:
func serveError(c *Context, code int, defaultMessage []byte) {
c.writermem.status = code
c.Next()
if c.writermem.Written() {
return
}
if c.writermem.Status() == code {
c.writermem.Header()["Content-Type"] = mimePlain
_, err := c.Writer.Write(defaultMessage)
if err != nil {
debugPrint("cannot write message to writer during serve error: %v", err)
}
return
}
c.writermem.WriteHeaderNow()
}
梳理
这里主要关注c.Next()
. 在写中间件的时候也会使用到, 将context交给下一个Handler处理. 在ServeHTTP
中已经将engine.allNoRoute赋值给c.Handlers了, 那么这里的c.Next()应该会执行c.Handlers的第一个函数(如果有).
看到这里, gin处理一次请求大概清楚了:
后端接收一个请求
- gin首先解析一些数据如Method, Path等
- 根据Method匹配MethodTree(里面有路径及相关Handlers)
- 如果匹配成功, 根据Path匹配node; 否则进入serveError(默认情况下)
- 如果匹配Path成功, 执行对应的Handlers, 否则进入serveError(如果没能执行尾随斜杠重定向和allNoMethod)
进入serveError后会执行engine.allNoRoute的Handlers, 如果期间对c进行了写操作(不包括status, 一般是header或者body), 那么c.writermem.Written()就会为true, 直接返回; 如果没有并且期间修改了status, 那在最后会修改再返回.
了解了这个过程之后就只要找组的中间件和全局中间件是怎么起作用的就可以了.
先点进r.Use
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...)
engine.rebuild404Handlers()
engine.rebuild405Handlers()
return engine
}
func (engine *Engine) rebuild404Handlers() {
engine.allNoRoute = engine.combineHandlers(engine.noRoute)
}
func (engine *Engine) rebuild405Handlers() {
engine.allNoMethod = engine.combineHandlers(engine.noMethod)
}
这里其实可以看出来, 全局中间件是会加入到allNoRoute和allNoMethod的.
再看看组的:
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
这里只改变了组自身的Handlers, 并没有加入到allNoRoute. 所以OPTIONS http://localhost:8080/api/v1/hello
不会被中间件捕获, 改成全局中间件才可以.
当然, 也可以使用NoRoute
r.NoRoute(middleware.CORSMiddleware())
在默认情况下应该是一样的.
继续查看源码
在实际翻看过程中我没这么顺利, 很多都看不懂. 于是翻了一些其他代码.
Engine和RouterGroup
type RouterGroup struct {
Handlers HandlersChain
basePath string
engine *Engine
root bool
}
RouterGroup
有一个指向engine的指针, 这样在注册路径的时候可以添加到engine的路由树里. 同时engine自己继承了RouterGroup, 使得他也可以注册路径. engine的engine指针指向自己.
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers HandlersChain) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
handlers = group.combineHandlers(handlers)
group.engine.addRoute(httpMethod, absolutePath, handlers)
return group.returnObj()
}
这是他们注册路径的函数, group.calculateAbsolutePath会将group的基路径(一般是'/')和相对路径拼在一起, group.combineHandlers则是将注册的处理器接到group.Handlers后面.
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
assert1(path[0] == '/', "path must begin with '/'")
assert1(method != "", "HTTP method can not be empty")
assert1(len(handlers) > 0, "there must be at least one handler")
debugPrintRoute(method, path, handlers)
root := engine.trees.get(method)
if root == nil {
root = new(node)
root.fullPath = "/"
engine.trees = append(engine.trees, methodTree{method: method, root: root})
}
root.addRoute(path, handlers)
if paramsCount := countParams(path); paramsCount > engine.maxParams {
engine.maxParams = paramsCount
}
if sectionsCount := countSections(path); sectionsCount > engine.maxSections {
engine.maxSections = sectionsCount
}
}
engine有一个tree: MethodTrees, 方法树数组, 或者叫方法树森林?get会返回对应方法树的根节点root: *node.
type methodTree struct {
method string
root *node
}
上面是methodTree结构体
type node struct {
path string
indices string
wildChild bool
nType nodeType
priority uint32
children []*node // child nodes, at most 1 :param style node at the end of the array
handlers HandlersChain
fullPath string
}