package queries

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
)

// DBTX returns the underlying database handle that q executes against.
func (q *Queries) DBTX() DBTX {
	return q.db
}

// InTx runs body inside a database transaction.
//
// If q already wraps a *sql.Tx, body is invoked directly with q: no new
// transaction is started and InTx neither commits nor rolls back, leaving
// that to whoever owns the surrounding transaction.
//
// Otherwise, if the underlying handle can begin a transaction (it has a
// BeginTx method), InTx starts one with opts, runs body with a Queries bound
// to it, and commits when body returns nil. If body returns an error, or
// panics, the transaction is rolled back; body's error is returned unchanged.
// A failure to begin or to commit the transaction is returned wrapped.
//
// Any other handle type yields an error without invoking body.
func (q *Queries) InTx(
	ctx context.Context,
	opts *sql.TxOptions,
	body func(context.Context, *Queries) error,
) error {
	switch db := q.db.(type) {
	case *sql.Tx:
		// Already inside a transaction: join it rather than nesting.
		return body(ctx, q)

	case interface {
		BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
	}:
		tx, err := db.BeginTx(ctx, opts)
		if err != nil {
			return fmt.Errorf("Queries.InTx failed to create tx: %w", err)
		}

		// Always attempt a rollback on the way out. After a successful Commit
		// this is a harmless no-op (it reports sql.ErrTxDone), but it ensures
		// the transaction is released when body fails or panics. The rollback
		// error is deliberately dropped so it cannot mask the primary failure.
		defer func() { _ = tx.Rollback() }()

		if err := body(ctx, New(tx)); err != nil {
			return err
		}

		// Commit explicitly so that a failed commit is reported to the caller
		// instead of being silently discarded.
		if err := tx.Commit(); err != nil {
			return fmt.Errorf("Queries.InTx failed to commit tx: %w", err)
		}
		return nil

	default:
		return errors.New("Queries.InTx: invalid DBTX type")
	}
}