go-transaction-manager
49 строк · 1.2 Кб
1package gorm
2
3import (
4"context"
5
6"gorm.io/gorm"
7
8"github.com/avito-tech/go-transaction-manager/trm/v2"
9trmcontext "github.com/avito-tech/go-transaction-manager/trm/v2/context"
10)
11
12// DefaultCtxGetter is the CtxGetter with settings.DefaultCtxKey.
13var DefaultCtxGetter = NewCtxGetter(trmcontext.DefaultManager)
14
15// CtxGetter gets Tr from trm.СtxManager by casting trm.Transaction to *gorm.DB.
16type CtxGetter struct {
17ctxManager trm.СtxManager
18}
19
20// NewCtxGetter returns *CtxGetter to get *gorm.DB from context.Context.
21func NewCtxGetter(c trm.СtxManager) *CtxGetter {
22return &CtxGetter{ctxManager: c}
23}
24
25// DefaultTrOrDB returns Tr(*gorm.DB) from context.Context or DB(*gorm.DB) otherwise.
26func (c *CtxGetter) DefaultTrOrDB(ctx context.Context, db *gorm.DB) *gorm.DB {
27if tr := c.ctxManager.Default(ctx); tr != nil {
28return c.convert(tr)
29}
30
31return db
32}
33
34// TrOrDB returns Tr(*gorm.DB) from context.Context by trm.CtxKey or DB(*gorm.DB) otherwise.
35func (c *CtxGetter) TrOrDB(ctx context.Context, key trm.CtxKey, db *gorm.DB) *gorm.DB {
36if tr := c.ctxManager.ByKey(ctx, key); tr != nil {
37return c.convert(tr)
38}
39
40return db
41}
42
43func (c *CtxGetter) convert(tr trm.Transaction) *gorm.DB {
44if tx, ok := tr.Transaction().(*gorm.DB); ok {
45return tx
46}
47
48return nil
49}
50