summaryrefslogtreecommitdiff
path: root/internal/queries/dbtx.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/queries/dbtx.go')
-rw-r--r--internal/queries/dbtx.go44
1 files changed, 44 insertions, 0 deletions
diff --git a/internal/queries/dbtx.go b/internal/queries/dbtx.go
new file mode 100644
index 0000000..ba96dfb
--- /dev/null
+++ b/internal/queries/dbtx.go
@@ -0,0 +1,44 @@
+package queries
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+)
+
+func (q *Queries) DBTX() DBTX {
+ return q.db
+}
+
+func (q *Queries) InTx(
+ ctx context.Context,
+ opts *sql.TxOptions,
+ body func(context.Context, *Queries) error,
+) error {
+ var tx *sql.Tx
+ var err error
+
+ switch db := q.db.(type) {
+ case *sql.Tx:
+ return body(ctx, q)
+ case interface {
+ BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
+ }:
+ tx, err = db.BeginTx(ctx, opts)
+ if err != nil {
+ return fmt.Errorf("Queries.InTx failed to create tx: %w", err)
+ }
+ defer func() {
+ if err == nil {
+ _ = tx.Commit()
+ } else {
+ _ = tx.Rollback()
+ }
+ }()
+ err = body(ctx, New(tx))
+ return err
+ default:
+ return errors.New("Queries.InTx: invalid DBTX type")
+ }
+}