diff options
Diffstat (limited to 'internal/queries/dbtx.go')
-rw-r--r-- | internal/queries/dbtx.go | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/internal/queries/dbtx.go b/internal/queries/dbtx.go new file mode 100644 index 0000000..ba96dfb --- /dev/null +++ b/internal/queries/dbtx.go @@ -0,0 +1,44 @@ +package queries + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +func (q *Queries) DBTX() DBTX { + return q.db +} + +func (q *Queries) InTx( + ctx context.Context, + opts *sql.TxOptions, + body func(context.Context, *Queries) error, +) error { + var tx *sql.Tx + var err error + + switch db := q.db.(type) { + case *sql.Tx: + return body(ctx, q) + case interface { + BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) + }: + tx, err = db.BeginTx(ctx, opts) + if err != nil { + return fmt.Errorf("Queries.InTx failed to create tx: %w", err) + } + defer func() { + if err == nil { + _ = tx.Commit() + } else { + _ = tx.Rollback() + } + }() + err = body(ctx, New(tx)) + return err + default: + return errors.New("Queries.InTx: invalid DBTX type") + } +} |