首页 「Yogin」实现错误恢复和基本鉴权
文章
取消

「Yogin」实现错误恢复和基本鉴权

从错误中恢复

要实现一个可用的服务器框架,错误恢复是必不可少的。由于代码编写有误或是生产环境中发生了没有预料到的情况,服务器在运行时很可能会发生panic。这个panic往往是服务器处理某个用户发来的请求产生的,我们不希望服务器因为处理一个用户请求出错进而连其它用户都无法响应的局面。因此,我们希望服务器具有从错误中恢复的能力。

panic和recover

Go语言提供了panicrecover机制,实现程序的异常和恢复。在编写程序时,为了让代码中的错误尽可能暴露出来,我们希望用panic的方式检查特定条件是否满足,若不满足,直接用panic终止程序执行。

recover可以捕捉到panic,但只能在defer中捕捉到,在panic前后的语句都无法捕捉到。这是很容易理解的,因为在其之前的recover由于并没有panic产生因此什么也捕捉不到,而在其之后的recover则永远无法执行到,因为panic直接改变了程序的控制流,阻止其之后的语句执行。

Go的运行时会执行完当前发生panic的goroutine中的defer块,在defer块中我们可以使用recover,此时该函数返回panic中的参数。

下面是一些有关panicrecover重要观察1

  • panic只会触发当前goroutine的defer
  • panic允许在defer中嵌套多次调用;
  • recover只有在defer中调用才会生效,且不能越级捕获。

跨协程失效

下面的程序中,main goroutine中的defer无法处理其新建的goroutine中的panic

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
func main() {
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("in main recover: %v\n", r)
        }
    }()
    go func() {
        defer println("in goroutine")
        panic("goroutine panic")
    }()

    time.Sleep(1 * time.Second)
}
// in goroutine
// panic: goroutine panic

嵌套panic

defer中可以再次抛出panic,并且defer中可以再用一个defer来捕捉这样的panic

在下面程序的输出中,在main函数panic后,3个defer按后入先出的方式执行,由于panic没有全部处理,因此执行完后再打印panic的信息,打印顺序按照panic发生的顺序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
func main() {
    defer fmt.Println("in main 1")
    
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("in recover 1: %v\n", r)
        }
        
        defer func() {
            if r := recover(); r != nil {
                fmt.Printf("in recover 2: %v\n", r)
            }
            panic("panic again and again")
        }()
        
        panic("panic again")
    }()
    
    defer fmt.Println("in main 2")
    
    panic("panic once")
}

// in main 2
// in recover 1: panic once
// in recover 2: panic again
// in main 1
//     panic: panic once [recovered]
//     panic: panic again [recovered]
//     panic: panic again and again

使用recover的反例

recover不能越级捕获捕获panic,只有在defer后的匿名函数体中才能成功捕获。下面的代码中,defer里的defer和匿名函数都无法处理panic

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
func main() {
    defer func() {
        defer func() {
            if r := recover(); r != nil {
                fmt.Printf("in defer defer recover: %v\n", r)
            }
        }()

        func() {
            if r := recover(); r != nil {
                fmt.Printf("in defer func recover: %v\n", r)
            }
        }()
    }()
    defer recover()
    panic("panic")
}
// panic: panic

此外,不能用defer recover()捕获同一层的panic,可以简单理解为defer recover()语句会让recover在执行到defer关键字时就会立刻执行,因此,我们可以观察到下面的程序中panic被处理了:

1
2
3
4
5
6
7
func main() {
    defer func() {
        defer recover()
    }()
    panic("panic")
}
// 无报错

更多观察

理解了deferpanicrecover的执行顺序和作用效果,我们可以理解下面的现象:

我们可以在deferpanic,然后在其之前的defer中用recover处理,如果调换二者顺序则失效。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
func main() {
    defer fmt.Println("in main 1")

    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("in recover: %v\n", r)
        }
    }()

    defer func() {
        panic("panic")
    }()

    defer fmt.Println("in main 2")
}
// in main 2
// in recover: panic
// in main 1

defer中多次panic后才recover,只有最后执行的panic信息会被捕捉到,且只处理一次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
func main() {
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("in recover 2: %v\n", r)
        }
    }()
    defer func() {
        if r := recover(); r != nil {
            fmt.Printf("in recover 1: %v\n", r)
        }
    }()

    defer func() {
        panic("panic 2")
    }()
    defer func() {
        panic("panic 1")
    }()

    panic("panic 0")
}
// in recover 1: panic 2

错误恢复中间件

在了解go中的错误恢复机制后,我们可以来开始编写错误恢复中间件了。一个错误恢复中间件的主体结构十分简单:

1
2
3
4
5
6
7
8
9
10
11
12
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
    return func(c *Context) {
        defer func() {
            if err := recover(); err != nil {
                // handle recovery ... 
            }
        }()

        c.Next()
    }
}

只需要在其中写一个defer函数处理由c.Next()报出的错误即可。

区分客户端连接断开和服务端出错

服务端报错有两种原因,即客户端连接断开或服务端处理出错。在恢复时,前者的panic不是由于代码编写错误导致的,我们不需要做特别处理,只需要中断后续的中间件流程。后者则需要我们打印调用栈,帮助我们找到报错来源。

我们直接参考gin的实现2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// RecoveryFunc defines the function passable to CustomRecovery.
type RecoveryFunc func(c *Context, err interface{})

func defaultHandleRecovery(c *Context, err interface{}) {
    c.AbortWithStatus(http.StatusInternalServerError)
}

// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
func Recovery() HandlerFunc {
    logger := log.New(DefaultErrorWriter, "\n\n\x1b[31m", log.LstdFlags)

    return func(c *Context) {
        defer func() {
            if err := recover(); err != nil {
                // Check for a broken connection, as it is not really a
                // condition that warrants a panic stack trace.
                brokenPipe := false
                if ne, ok := err.(*net.OpError); ok {
                    var se *os.SyscallError
                    if errors.As(ne, &se) {
                        if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
                            brokenPipe = true
                        }
                    }
                }

                stack := stack(3)
                httpRequest, _ := httputil.DumpRequest(c.Request, false)
                headersToStr := string(httpRequest)
                if brokenPipe {
                    logger.Printf("%s\n%s%s", err, headersToStr, reset)
                } else {
                    logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s",
                                time.Now().Format("2006/01/02 - 15:04:05"), headersToStr, err, stack, reset)
                }

                if brokenPipe {
                    // If the connection is dead, we can't write a status to it.
                    c.Error(err.(error))
                    c.Abort()
                } else {
                    defaultHandleRecovery(c, err)
                }
            }
        }()

        c.Next()
    }
}

// stack returns a nicely formatted stack frame, skipping skip frames.
func stack(skip int) string {
    var pcs [32]uintptr
    n := runtime.Callers(skip, pcs[:])

    var b strings.Builder
    b.WriteString("Traceback:")

    frames := runtime.CallersFrames(pcs[:n])
    for {
        frame, more := frames.Next()
        b.WriteString(fmt.Sprintf("\n\t%s:%d", frame.File, frame.Line))
        if !more {
            break
        }
    }
    return b.String()
}

RecoveryFunc作为恢复函数的接口,方便用户自定义恢复行为,我们直接适用默认的defaultHandleRecovery函数,返回客户端500状态码,用AbortWithStatus中断后续中间件流程。使用runtime提供的方法获取调用栈并打印,在gin的recovery.go源码2中,为了使报错信息更加友好,调用栈的打印方法更加复杂,感兴趣的读者可以去研究一下。

对于客户端连接断开的情形,操作系统返回broken pipereset报错,此时我们无须返回客户端响应,因此向上下文中写入错误。

基本鉴权中间件

在完成了日志、恢复等简单的中间件后,我们现在可以尝试实现更为复杂的中间件,来探索上下文Context中提供的Keys字段的作用:保存上游中间件的处理结果,供下游中间件使用。

在实践中,我们往往要求用户在登录后才能访问服务器的某些接口,这些接口往往与用户信息和用户权限是相关的,服务器在处理请求时,需要上下文先提供当前请求的用户名,才能执行后续与特定用户相关的逻辑。

经过分析,我们可以得出基本鉴权中间件的职责:首先验证请求体中是否包含用户已登录的相关参数,然后得到当前登录用户的用户名,使用上下文的Set方法向其中添加用户名,然后将控制流转到后续处理逻辑。在处理逻辑中,我们用上下文的Get方法获取用户名,然后执行与用户相关的逻辑。

若从请求体中推断出用户未登录或登录无效,则阻断后续的请求处理。

建立账号列表

我们给出BasicAuth中间件的基本框架:从请求头的Authorization字段中获取用户的“通行证”(token),在searchCredential函数中验证,该函数会遍历processAccounts函数返回的authPairsauthPair可视为一个由用户名及其“通信证”组成的二元组,“通行证”由用户名和密码经过hash得到。通过与各个账号的authPair逐一比对,确定用户身份。

初始化BasicAuth中间件时,我们传入所有的用户账号accounts,用processAccounts得到这些账号的身份token。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
func BasicAuth(accounts Accounts) HandlerFunc {
    pairs := processAccounts(accounts)
    return func(c *Context) {
        // find and validate auth token, and get username from pairs
        user, found := pairs.searchCredential(c.GetHeader("Authorization"))
        if !found {
            c.Header("WWW-Authenticate", basicRealm)
            c.AbortWithStatus(http.StatusUnauthorized)
            return
        }
        // set username
        c.Set(AuthUserKey, user)
        c.Next()
    }
}

// Accounts defines a key/value for user/pass list of authorized logins.
type Accounts map[string]string

type authPair struct {
    value string
    user  string
}

type authPairs []authPair

func processAccounts(accounts Accounts) authPairs {
    length := len(accounts)
    assert1(length > 0, "Empty list of authorized credentials")

    pairs := make(authPairs, 0, length)
    for user, password := range accounts {
        assert1(user != "", "User can not be empty")

        value := secretHash(user, password)
        pairs = append(pairs, authPair{
            value: value,
            user:  user,
        })
    }
    return pairs
}

“常数时间”字符串比对

在比较用户token是否正确时,我们用到了ConstantTimeCompare3函数:

1
2
3
4
5
6
7
8
9
10
11
func (a authPairs) searchCredential(authValue string) (string, bool) {
    if authValue == "" {
        return "", false
    }
    for _, pair := range a {
        if subtle.ConstantTimeCompare([]byte(pair.value), []byte(authValue)) == 1 {
            return pair.user, true
        }
    }
    return "", false
}

该函数可以缓解针对用户密码验证算法的Timing Attack4。如果直接适用普通的字符串比较函数,这类往往在发现某个位置字符不一致后就返回了,相当于告诉黑客,“你猜的密码前几个字节是正确的”。黑客只要多猜几次,就可以猜出用户的密码。

“常数时间”的字符串比对并意味着这个算法是O(1)的,其执行时间仍然与输入长度相关。该算法做出的保证是,无论两个字符串在哪一位出错,它都会比较到最后一位才返回,因此黑客无法根据算法执行时间推断出自己猜测的密码的正确程度。该算法的一个典型实现是先比较两个串的长度是否相等,然后使用位运算(如异或)检查每个字节是否一致,复杂度为O(n)

示例

让我们来尝试一下本文编写的中间件,完成一个简单的用户登录和调用用户相关API的用例。首先用新增的Default方法创建一个使用日志和恢复中间件的框架示例,向其中添加两条路由。

  • POST /login:在请求表单中传入用户名和密码登录,成功后服务器返回token
  • GET /admin/secrets:利用服务器返回的token,访问用户自己的私有数据。该路由属于/admin路由组,这个路由组中我们使用了BasicAuth中间件。

当然,在实践中我们并不会在表单里用明文发送密码,而往往以base64等编码先对用户的用户名和密码加密。另外,由于HTTP请求是无状态的,服务器往往会对用户新建Session管理用户服务上下文。在下一篇文章中,我们会用CookieSession再次实现一种鉴权方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
accounts := Accounts{
    "foo":    "bar",
}
secrets := map[string]interface{}{
    "foo":    map[string]interface{}{"email": "foo@bar.com", "phone": "123433"},
}
const noSecret = "NO SECRET :("
const publicInfo = "PUBLIC INFO :)"

r := Default()

// login interface, return secret token
r.POST("/login", func(c *Context) {
    user := c.PostForm("user")
    password := c.PostForm("password")
    if pwd, ok := accounts[user]; ok {
        if password == pwd {
            c.String(http.StatusOK, secretHash(user, password))
            return
        }
    }
    c.Forbidden().WithString("wrong user or password")
})

// routes under authorized group requires user to login first
authorized := r.Group("/admin", BasicAuth(accounts))

authorized.GET("/secrets", func(c *Context) {
    // get user, it was set by the BasicAuth middleware
    user := c.MustGet(AuthUserKey).(string)
    if secret, ok := secrets[user]; ok {
        c.OK().WithJSON(H{"user": user, "secret": secret})
    } else {
        c.OK().WithJSON(H{"user": user, "secret": noSecret})
    }
})

// test legal users
user := "foo"
password := "bar"
var token string
{
    req := httptest.NewRequest(http.MethodPost, "/login", nil)
    req.PostForm = make(url.Values)
    req.PostForm["user"] = []string{user}
    req.PostForm["password"] = []string{password}
    w := httptest.NewRecorder()
    r.ServeHTTP(w, req)
    token = w.Body.String()
}

{
    req := httptest.NewRequest(http.MethodGet, "/admin/secrets", nil)
    req.Header.Set("Authorization", token)
    w := httptest.NewRecorder()
    r.ServeHTTP(w, req)
}

post03_test.go中,我提供了更详细的测试用例,包括对服务器panic的测试。此外,在鉴权测试中,还演示了未登录(登录失败)用户访问/admin下的API被阻断,以及所有用户都可访问/publicAPI的场景。

完整代码仓库

yogin

「Yogin」系列全部代码可在我的GitHub代码仓库中查看:Yogin is Your Own Gin

欢迎提出各类宝贵的修改意见和issues,指出其中的错误和不足!

最后,感谢你读到这里,希望我们都有所收获!

References

本文由作者按照 CC BY 4.0 进行授权

「Yogin」实现分组路由和日志中间件

「Yogin」更多功能:模板渲染、限流和Session管理