ref: master
db/middleware.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
package db

import (
	"context"
	"database/sql"
	"errors"
	"net/http"
)

// dbCtxKey is the context key under which the *sql.DB is stored. It is a
// pointer to a value of an unexported type, so no other package can
// construct an equal key and collide with it.
var dbCtxKey = &contextKey{"db"}

// contextKey is the unexported key type for context values set by this
// package; name only aids debugging.
type contextKey struct {
	name string
}

// Middleware returns HTTP middleware that attaches db to each request's
// context (via Context) before calling next. Downstream handlers retrieve
// it with ForContext or FromContext.
func Middleware(db *sql.DB) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r.WithContext(Context(r.Context(), db)))
		})
	}
}

// Context returns a copy of ctx that carries db.
func Context(ctx context.Context, db *sql.DB) context.Context {
	return context.WithValue(ctx, dbCtxKey, db)
}

// FromContext returns the *sql.DB stored in ctx by Context or Middleware.
// The boolean is false when ctx carries no database, or when the stored
// value is a nil *sql.DB (treated as absent so the nil cannot escape and
// fail later at first use).
func FromContext(ctx context.Context) (*sql.DB, bool) {
	db, ok := ctx.Value(dbCtxKey).(*sql.DB)
	if !ok || db == nil {
		return nil, false
	}
	return db, true
}

// ForContext returns the *sql.DB stored in ctx. It panics when FromContext
// would report false; that usually means the request did not pass through
// Middleware, which is a wiring bug. Use FromContext when absence should be
// handled without panicking.
func ForContext(ctx context.Context) *sql.DB {
	db, ok := FromContext(ctx)
	if !ok {
		panic(errors.New("Invalid db context"))
	}
	return db
}