ref: master
database/middleware.go
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/http"
)

// dbCtxKey is the context key under which Context stores the *sql.DB.
// A pointer to an unexported type cannot collide with keys from other packages.
var dbCtxKey = &contextKey{"database"}

// contextKey is the unexported type used for this package's context keys.
type contextKey struct {
	name string
}

// Middleware returns HTTP middleware that attaches db to each request's
// context, so downstream handlers can reach it through ForContext or WithTx.
func Middleware(db *sql.DB) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := Context(r.Context(), db)
			r = r.WithContext(ctx)
			next.ServeHTTP(w, r)
		})
	}
}

// Context returns a copy of ctx that carries db.
func Context(ctx context.Context, db *sql.DB) context.Context {
	return context.WithValue(ctx, dbCtxKey, db)
}

// ForContext obtains a dedicated connection from the *sql.DB stored in ctx by
// Context (or Middleware). The caller must Close the returned *sql.Conn so it
// goes back to the pool.
//
// It panics if ctx carries no database (or a nil one). That indicates the
// middleware was not installed, which is a programming error rather than a
// runtime condition.
func ForContext(ctx context.Context) (*sql.Conn, error) {
	db, ok := ctx.Value(dbCtxKey).(*sql.DB)
	if !ok || db == nil {
		panic(errors.New("Invalid database context"))
	}
	return db.Conn(ctx)
}

// WithTx runs fn inside a transaction on a connection taken from the database
// stored in ctx (see ForContext), using opts for BeginTx.
//
//   - fn returns nil: the transaction is committed, and a commit failure is
//     returned to the caller.
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails, its error is appended while fn's
//     error stays matchable with errors.Is.
//   - fn panics: the transaction is rolled back and the panic is re-raised.
//
// fn may end the transaction itself. The resulting sql.ErrTxDone from the
// automatic Commit/Rollback is ignored, except when ctx has been cancelled:
// database/sql rolls back transactions whose context is cancelled, so
// ctx.Err() is returned instead of reporting a commit that never happened.
func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error) error {
	conn, err := ForContext(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()

	tx, err := conn.BeginTx(ctx, opts)
	if err != nil {
		return err
	}

	// If fn panics, roll back before letting the panic continue so the
	// transaction is not left open on the connection.
	defer func() {
		if r := recover(); r != nil {
			_ = tx.Rollback() // best effort; the panic is the failure that matters
			panic(r)
		}
	}()

	if fnErr := fn(tx); fnErr != nil {
		// sql.ErrTxDone means the transaction was already ended (by fn or by
		// context cancellation), so there is nothing left to roll back.
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			return fmt.Errorf("%w (rollback failed: %v)", fnErr, rbErr)
		}
		return fnErr
	}

	if cmErr := tx.Commit(); cmErr != nil {
		if !errors.Is(cmErr, sql.ErrTxDone) {
			return cmErr
		}
		// ErrTxDone: either fn finished the transaction itself, or ctx was
		// cancelled and database/sql rolled it back. Reporting success in the
		// latter case would hide lost writes.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
	}
	return nil
}