英文:
Rate Limiting HTTP Requests (via http.HandlerFunc middleware)
问题
我正在寻找编写一个小型的速率限制中间件,它具有以下功能:
- 允许我为每个远程 IP 设置一个合理的速率(比如每秒 10 次请求)。
- 可能(但不一定)允许突发请求。
- 关闭超过速率限制的连接,并返回 HTTP 429 错误。
然后,我可以将其应用于身份验证路由或其他可能容易受到暴力破解攻击的路由(例如使用过期令牌的密码重置 URL 等)。虽然有人通过暴力破解来猜测一个 16 或 24 字节的令牌的几率非常低,但多一层保护也无妨。
我已经查看了 https://code.google.com/p/go-wiki/wiki/RateLimiting,但不确定如何与 http.Request(s) 结合使用。此外,我不确定如何在一段时间内“跟踪”来自特定 IP 的请求。
理想情况下,我希望得到类似下面的代码,注意我在反向代理(nginx)后面,所以我们要检查 REMOTE_ADDR
HTTP 头而不是使用 r.RemoteAddr
:
// 速率限制中间件
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
remoteIP := r.Header.Get("REMOTE_ADDR")
for req := range (这里是什么?) {
// 这里是什么?
// 如果超过限制,设置 w.WriteHeader(429) 并关闭请求
// 否则传递给链中的下一个处理程序
h.ServeHTTP(w, r)
}
}
// 示例路由
r.HandleFunc("/login", use(loginForm, rateLimit, csrf))
r.HandleFunc("/form", use(editHandler, rateLimit, csrf))
// 中间件包装器,用于上下文
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
for _, m := range middleware {
h = m(h)
}
return h
}
我希望能得到一些指导。
英文:
I'm looking to write a small piece of rate-limiting middleware that:
- Allows me to set a sensible rate (say, 10 req/s) per remote IP
- Possibly (but it doesn't have to) allow for bursts
- Drops (closes?) connections that exceed the rate and returns a HTTP 429
I can then wrap this around authentication routes or other routes that might be vulnerable to brute-force attacks (i.e. password reset URLs using a token that expires, etc.). The chances of someone brute forcing a 16 or 24 byte token are really low, but it doesn't hurt to go that extra step.
I've had a look at https://code.google.com/p/go-wiki/wiki/RateLimiting but am not sure how to reconcile it with http.Request(s). Further, I'm not sure how we'd "track" requests from a given IP over any period of time.
Ideally I'd end up with something like this, noting that I'm behind a reverse proxy (nginx) so we're checking for the REMOTE_ADDR
HTTP header rather than using r.RemoteAddr
:
// Rate-limiting middleware
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
remoteIP := r.Header.Get("REMOTE_ADDR")
for req := range (what here?) {
// what here?
// w.WriteHeader(429) and close the request if it exceeds the limit
// else pass to the next handler in the chain
h.ServeHTTP(w, r)
}
}
// Example routes
r.HandleFunc("/login", use(loginForm, rateLimit, csrf)
r.HandleFunc("/form", use(editHandler, rateLimit, csrf)
// Middleware wrapper, for context
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
for _, m := range middleware {
h = m(h)
}
return h
}
I'd appreciate some guidance here.
答案1
得分: 11
你提供的示例是一个通用的速率限制示例。它使用range函数,因为它通过通道获取请求。
对于HTTP请求来说,情况有所不同,但这里没有什么复杂的东西。请注意,你不会遍历请求的通道,或者其他什么——你的HandlerFunc会为每个单独的传入请求调用。
现在,选择存储速率限制计数器的位置取决于你。一种解决方案是简单地使用一个全局映射(不要忘记进行安全并发访问),将IP映射到它们的请求计数器。然而,你需要注意请求是多久之前发出的。
Sergio建议使用Redis。它的键值特性非常适合这样的简单结构,而且你还可以免费获得过期功能。
英文:
The rate limiting example you've linked to is a general one. It uses range because it gets requests over a channel.
It's a different story with HTTP requests, but there's nothing really complicated here. Note that you don't iterate over a channel of requests, or anything -- your HandlerFunc is called for every incoming request separately.
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
remoteIP := r.Header.Get("REMOTE_ADDR")
if exceededTheLimit(remoteIP) {
w.WriteHeader(429)
// it then returns, not passing the request down the chain
} else {
h.ServeHTTP(w, r);
}
}
}
Now, choosing the place to store the rate limit counters is up to you. One solution would be to simply use a global map (don't forget safe concurrent access) that would map IPs to their request counters. However, you would have to be aware of how long ago the requests were made.
Sergio suggested using Redis. Its key-value nature is a perfect fit for simple structures like this and you get expiration for free.
答案2
得分: 4
你可以将数据存储在Redis中。这里有一个非常有用的命令,甚至在其文档中提到了速率限制应用程序:INCR。Redis还会处理旧数据的清理(通过过期旧键)。
此外,使用Redis作为速率限制器存储,您可以使用多个共享此中央存储的前端进程。
有人会争辩说,每次都去外部进程是昂贵的。但是密码重置页面并不是绝对需要最佳性能的页面。而且,如果将Redis放在同一台机器上,延迟应该会很低。
英文:
You could store the data in redis. Here's a very useful command that even mentions rate limiting application in its documentation: INCR. Redis will also handle cleanup of old data (via expiration of old keys).
Also, with redis being the rate limiter storage, you can use multiple frontend processes that share this central storage.
Some would argue that going to external process each time is expensive. But password reset page is not a kind of page that absolutely demands best performance. Also, if you place the redis on the same machine, latency should be pretty low.
答案3
得分: 4
我今天早上做了一些简单而类似的事情,我认为它可能对你的情况有帮助。
package main
import (
"log"
"net/http"
"strings"
"time"
)
func main() {
fs := http.FileServer(http.Dir("./html/"))
http.Handle("/", fs)
log.Println("Listening..")
go clearLastRequestsIPs()
go clearBlockedIPs()
err := http.ListenAndServe(":8080", middleware(nil))
if err != nil {
log.Fatalln(err)
}
}
// 存储最近请求的IP地址
var lastRequestsIPs []string
// 封锁IP地址6小时
var blockedIPs []string
func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ipAddr := strings.Split(r.RemoteAddr, ":")[0]
if existsBlockedIP(ipAddr) {
http.Error(w, "", http.StatusTooManyRequests)
return
}
// 当前IP地址在过去5分钟内发起的请求数量
requestCounter := 0
for _, ip := range lastRequestsIPs {
if ip == ipAddr {
requestCounter++
}
}
if requestCounter >= 1000 {
blockedIPs = append(blockedIPs, ipAddr)
http.Error(w, "", http.StatusTooManyRequests)
return
}
lastRequestsIPs = append(lastRequestsIPs, ipAddr)
// 不中断中间件链
if next == nil {
http.DefaultServeMux.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func existsBlockedIP(ipAddr string) bool {
for _, ip := range blockedIPs {
if ip == ipAddr {
return true
}
}
return false
}
func existsLastRequest(ipAddr string) bool {
for _, ip := range lastRequestsIPs {
if ip == ipAddr {
return true
}
}
return false
}
// 每5分钟清空lastRequestsIPs数组
func clearLastRequestsIPs() {
for {
lastRequestsIPs = []string{}
time.Sleep(time.Minute * 5)
}
}
// 每6小时清空blockedIPs数组
func clearBlockedIPs() {
for {
blockedIPs = []string{}
time.Sleep(time.Hour * 6)
}
}
这个例子还不够精确,但它可以作为一个简单的速率限制器示例。你可以通过添加请求路径、HTTP方法甚至身份验证作为判断流量是否是攻击的因素来改进它。
英文:
I have done something simple and similar this morning, I think it could help your case.
package main
import (
"log"
"net/http"
"strings"
"time"
)
func main() {
fs := http.FileServer(http.Dir("./html/"))
http.Handle("/", fs)
log.Println("Listening..")
go clearLastRequestsIPs()
go clearBlockedIPs()
err := http.ListenAndServe(":8080", middleware(nil))
if err != nil {
log.Fatalln(err)
}
}
// Stores last requests IPs
var lastRequestsIPs []string
// Block IP for 6 hours
var blockedIPs []string
func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ipAddr := strings.Split(r.RemoteAddr, ":")[0]
if existsBlockedIP(ipAddr) {
http.Error(w, "", http.StatusTooManyRequests)
return
}
// how many requests the current IP made in last 5 mins
requestCounter := 0
for _, ip := range lastRequestsIPs {
if ip == ipAddr {
requestCounter++
}
}
if requestCounter >= 1000 {
blockedIPs = append(blockedIPs, ipAddr)
http.Error(w, "", http.StatusTooManyRequests)
return
}
lastRequestsIPs = append(lastRequestsIPs, ipAddr)
// Don't cut the chain of middlewares
if next == nil {
http.DefaultServeMux.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func existsBlockedIP(ipAddr string) bool {
for _, ip := range blockedIPs {
if ip == ipAddr {
return true
}
}
return false
}
func existsLastRequest(ipAddr string) bool {
for _, ip := range lastRequestsIPs {
if ip == ipAddr {
return true
}
}
return false
}
// Clears lastRequestsIPs array every 5 mins
func clearLastRequestsIPs() {
for {
lastRequestsIPs = []string{}
time.Sleep(time.Minute * 5)
}
}
// Clears blockedIPs array every 6 hours
func clearBlockedIPs() {
for {
blockedIPs = []string{}
time.Sleep(time.Hour * 6)
}
}
It's still not precise yet, however, it would help as a simple example of rate limiter. you can improve it by adding requested path, http method and even authentication as factors to decide whether the flow is an attack or not.
答案4
得分: 2
这是我的速率限制中间件实现。它可以作为全局速率限制器或单个请求的速率限制器,非常好用。我在我的应用程序中广泛使用它。
以下是它的特点:
- 无需外部依赖
- 可测试
- 可配置
- 添加头部信息,以便客户端了解在达到限制之前还剩下多少请求等
- 自动删除过期数据
首先,是实现部分:
r := router.New()
stats := stats.New()
r.With(middleware.RateLimit(1, time.Minute * 1, stats)).Post("/contact", c.Contact)
上述中间件将允许每分钟进行一次请求,当向"/contact"发起POST
请求时。
以下是中间件的代码:
package middleware
import (
"net/http"
"strconv"
"time"
)
// Stats 是对底层哈希表/映射数据结构的接口。可以根据需要进行实现。
type Stats interface {
// Reset 将重置映射。
Reset()
// Add 将“count”添加到键“identifier”对应的映射中,并返回该键对应的值的总数。
Add(identifier string, count int) int
}
// RateLimit 中间件是一个通用的速率限制器,可以在任何场景中使用,因为它允许对每个特定请求进行细粒度的速率限制。或者您可以在整个路由组上设置速率限制器。它只是一个 HandlerFunc。
func RateLimit(limit int, window time.Duration, stats Stats) func(next http.Handler) http.Handler {
var windowStart time.Time
// 在每个窗口之后清除速率限制统计信息。
ticker := time.NewTicker(window)
go func() {
windowStart = time.Now()
for range ticker.C {
windowStart = time.Now()
stats.Reset()
}
}()
return func(next http.Handler) http.Handler {
h := func(w http.ResponseWriter, r *http.Request) {
value := int(stats.Add(identifyRequest(r), 1))
XRateLimitRemaining := limit - value
if XRateLimitRemaining < 0 {
XRateLimitRemaining = 0
}
w.Header().Add("X-Rate-Limit-Limit", strconv.Itoa(limit))
w.Header().Add("X-Rate-Limit-Remaining", strconv.Itoa(XRateLimitRemaining))
w.Header().Add("X-Rate-Limit-Reset", strconv.Itoa(int(window.Seconds()-time.Since(windowStart).Seconds())+1))
if value >= limit {
w.WriteHeader(429)
// 做其他操作...
} else {
next.ServeHTTP(w, r)
}
}
return http.HandlerFunc(h)
}
}
// identifyRequest 从请求上下文中获取标识符。
func identifyRequest(r *http.Request) string {
// 在这里识别您的请求(获取IP地址等)。
}
英文:
Here's my rate limit middleware implementation. It works very nicely as a global rate limiter, or a rate limiter for an individual request. I use it extensively in my apps.
Here is what you get with it:
- no external dependencies
- testable
- configurable
- adds headers so a client can understand how many requests that have left before they are limited, etc.
- automatically removes expired data.
First, the implementation:
r := router.New()
stats := stats.New()
r.With(middleware.RateLimit(1, time.Minute * 1, stats)).Post("/contact", c.Contact)
The middleware about will allow one request pet minute when making a POST
request to /contact
.
Here is the middleware:
package middleware
import (
"net/http"
"strconv"
"time"
)
// Stats is an interface to an underlying hash table/map data
// structure. Implement it however you'd like.
type Stats interface {
// Reset will reset the map.
Reset()
// Add would add "count" to the map at the key of "identifier",
// and returns an int which is the total count of the value
// at that key.
Add(identifier string, count int) int
}
// RateLimit middleware is a generic rate limiter that can be used in any scenario
// because it allows granular rate limiting for each specific request. Or you can
// set the rate limiter on the entire router group. It's just a HandlerFunc.
func RateLimit(limit int, window time.Duration, stats Stats) func(next http.Handler) http.Handler {
var windowStart time.Time
// Clear the rate limit stats after each window.
ticker := time.NewTicker(window)
go func() {
windowStart = time.Now()
for range ticker.C {
windowStart = time.Now()
stats.Reset()
}
}()
return func(next http.Handler) http.Handler {
h := func(w http.ResponseWriter, r *http.Request) {
value := int(stats.Add(identifyRequest(r), 1))
XRateLimitRemaining := limit - value
if XRateLimitRemaining < 0 {
XRateLimitRemaining = 0
}
w.Header().Add("X-Rate-Limit-Limit", strconv.Itoa(limit))
w.Header().Add("X-Rate-Limit-Remaining", strconv.Itoa(XRateLimitRemaining))
w.Header().Add("X-Rate-Limit-Reset", strconv.Itoa(int(window.Seconds()-time.Since(windowStart).Seconds())+1))
if value >= limit {
w.WriteHeader(429)
// Do something else...
} else {
next.ServeHTTP(w, r)
}
}
return http.HandlerFunc(h)
}
}
// identifyRequest gets an identifier from the request context.
func identifyRequest(r *http.Request) string {
// Identify your request here (get IP address, etc.)
}
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论