go-transaction-manager
138 строк · 2.8 Кб
1package sql
2
3import (
4"context"
5"database/sql"
6"fmt"
7"sync/atomic"
8
9"go.uber.org/multierr"
10
11"github.com/avito-tech/go-transaction-manager/trm/v2"
12"github.com/avito-tech/go-transaction-manager/trm/v2/drivers"
13)
14
15// Transaction is trm.Transaction for sql.Tx.
16// trm.NestedTrFactory returns IsActive as true while trm.Transaction is opened.
17type Transaction struct {
18tx *sql.Tx
19savePoint SavePoint
20saves int64
21isClosed *drivers.IsClosed
22}
23
// NewTransaction begins a *sql.Tx on db using opts and wraps it in a
// Transaction that uses sp to build savepoint queries for nested
// transactions.
//
// The returned context is always ctx itself. If BeginTx fails, the
// *Transaction is nil and the BeginTx error is returned unchanged.
func NewTransaction(
	ctx context.Context,
	sp SavePoint,
	opts *sql.TxOptions,
	db *sql.DB,
) (context.Context, *Transaction, error) {
	sqlTx, beginErr := db.BeginTx(ctx, opts)
	if beginErr != nil {
		return ctx, nil, beginErr
	}

	tr := new(Transaction)
	tr.tx = sqlTx
	tr.savePoint = sp
	tr.isClosed = drivers.NewIsClosed()

	// Watch ctx in the background so the transaction is marked closed if the
	// context is done before Commit/Rollback (database/sql rolls the Tx back
	// on context cancellation, see sql.DB.BeginTx).
	go tr.awaitDone(ctx)

	return ctx, tr, nil
}
47
// awaitDone marks the transaction closed if ctx is done before the
// transaction is committed or rolled back.
//
// It returns immediately for contexts that can never be done (Done returns
// nil). Otherwise it returns as soon as either ctx is done or the transaction
// is closed.
func (t *Transaction) awaitDone(ctx context.Context) {
	done := ctx.Done()
	if done == nil {
		return
	}

	select {
	case <-t.isClosed.Closed():
		// Finished normally; nothing left to watch.
	case <-done:
		t.isClosed.Close()
	}
}
59
60// Transaction returns the real transaction sqlx.Tx.
61func (t *Transaction) Transaction() interface{} {
62return t.tx
63}
64
65// Begin nested transaction by save point.
66func (t *Transaction) Begin(ctx context.Context, _ trm.Settings) (context.Context, trm.Transaction, error) {
67_, err := t.tx.ExecContext(ctx, t.savePoint.Create(t.incrementID()))
68if err != nil {
69// decrement save point ID after error
70t.decrementID()
71
72return ctx, nil, err
73}
74
75return ctx, t, nil
76}
77
78// Commit the trm.Transaction.
79func (t *Transaction) Commit(ctx context.Context) error {
80if t.hasSavePoint() {
81_, err := t.tx.ExecContext(ctx, t.savePoint.Release(t.decrementID()))
82if err != nil {
83return multierr.Combine(trm.ErrNestedCommit, err)
84}
85
86return nil
87}
88
89defer t.isClosed.Close()
90
91return t.tx.Commit()
92}
93
94// Rollback the trm.Transaction.
95func (t *Transaction) Rollback(ctx context.Context) error {
96if t.hasSavePoint() {
97_, err := t.tx.ExecContext(ctx, t.savePoint.Rollback(t.decrementID()))
98if err != nil {
99return multierr.Combine(trm.ErrNestedRollback, err)
100}
101
102return nil
103}
104
105defer t.isClosed.Close()
106
107return t.tx.Rollback()
108}
109
110// IsActive returns true if the transaction started but not committed or rolled back.
111func (t *Transaction) IsActive() bool {
112return t.isClosed.IsActive()
113}
114
115// Closed returns a channel that's closed when transaction committed or rolled back.
116func (t *Transaction) Closed() <-chan struct{} {
117return t.isClosed.Closed()
118}
119
// hasSavePoint reports whether at least one savepoint (nested transaction)
// is currently open.
func (t *Transaction) hasSavePoint() bool {
	open := atomic.LoadInt64(&t.saves)

	return open >= 1
}
123
// incrementID opens one more savepoint level and returns its name: "tx_1" for
// the first nested transaction, "tx_2" for the next, and so on.
//
// The new depth is taken from the value returned by atomic.AddInt64 rather
// than re-read afterwards. Two concurrent callers therefore cannot both
// observe the same depth and end up with the same savepoint name.
func (t *Transaction) incrementID() string {
	depth := atomic.AddInt64(&t.saves, 1)

	return fmt.Sprintf("tx_%d", depth)
}
129
// decrementID closes the innermost savepoint level and returns the name that
// level had before the decrement, i.e. the name produced by the matching
// incrementID call.
//
// Both the decrement and the returned name come from a single
// atomic.AddInt64, so concurrent callers always receive distinct names.
func (t *Transaction) decrementID() string {
	depth := atomic.AddInt64(&t.saves, -1)

	return fmt.Sprintf("tx_%d", depth+1)
}
135
// id formats the savepoint name for the current nesting depth, e.g. "tx_2"
// when two savepoints are open. It does not modify the counter.
func (t *Transaction) id() string {
	depth := atomic.LoadInt64(&t.saves)

	return fmt.Sprintf("tx_%d", depth)
}
139