Author: Pedro Lucas Porcellis <porcellis@eletrotupi.com>
database: introduce a very basic middleware
database/middleware.go | 70 ++++++++++++++++++++++++++++++++++++++++++++
diff --git a/database/middleware.go b/database/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..2270eb661fab4f2d41e0d93e83231004bc50e8cb --- /dev/null +++ b/database/middleware.go @@ -0,0 +1,70 @@ +package database + +import ( + "context" + "database/sql" + "errors" + "net/http" +) + +var dbCtxKey = &contextKey{"database"} + +type contextKey struct { + name string +} + +func Middleware(db *sql.DB) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := Context(r.Context(), db) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } +} + +func Context(ctx context.Context, db *sql.DB) context.Context { + return context.WithValue(ctx, dbCtxKey, db) +} + +func ForContext(ctx context.Context) (*sql.Conn, error) { + raw, ok := ctx.Value(dbCtxKey).(*sql.DB) + if !ok { + panic(errors.New("Invalid database context")) + } + return raw.Conn(ctx) +} + +func WithTx(ctx context.Context, opts *sql.TxOptions, fn func(tx *sql.Tx) error) error { + conn, err := ForContext(ctx) + if err != nil { + return err + } + defer conn.Close() + + tx, err := conn.BeginTx(ctx, opts) + if err != nil { + return err + } + + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + err = fn(tx) + + if err != nil { + err := tx.Rollback() + if err != nil && err != sql.ErrTxDone { + panic(err) + } + } else { + err := tx.Commit() + if err != nil && err != sql.ErrTxDone { + panic(err) + } + } + return err +}