// moira
// 134 lines · 4.9 KB
1package middleware
2
3import (
4"context"
5"fmt"
6"net/http"
7"strconv"
8
9"github.com/go-chi/chi"
10"github.com/go-chi/render"
11
12"go.avito.ru/DO/moira"
13"go.avito.ru/DO/moira/api"
14)
15
16func ConfigContext(config api.Config) func(next http.Handler) http.Handler {
17return func(next http.Handler) http.Handler {
18return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
19ctx := context.WithValue(request.Context(), configKey, config)
20next.ServeHTTP(writer, request.WithContext(ctx))
21})
22}
23}
24
25// DatabaseContext sets to requests context configured database
26func DatabaseContext(database moira.Database) func(next http.Handler) http.Handler {
27return func(next http.Handler) http.Handler {
28return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
29ctx := context.WithValue(request.Context(), databaseKey, database)
30next.ServeHTTP(writer, request.WithContext(ctx))
31})
32}
33}
34
35// UserContext get x-webauth-user header and sets it in request context, if header is empty sets empty string
36func UserContext(next http.Handler) http.Handler {
37return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
38userLogin := request.Header.Get("x-webauth-user")
39ctx := context.WithValue(request.Context(), loginKey, userLogin)
40next.ServeHTTP(writer, request.WithContext(ctx))
41})
42}
43
44// TriggerContext gets triggerId from parsed URI corresponding to trigger routes and set it to request context
45func TriggerContext(next http.Handler) http.Handler {
46return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
47triggerID := chi.URLParam(request, "triggerId")
48if triggerID == "" {
49render.Render(writer, request, api.ErrorInvalidRequest(fmt.Errorf("TriggerID must be set")))
50return
51}
52ctx := context.WithValue(request.Context(), triggerIDKey, triggerID)
53next.ServeHTTP(writer, request.WithContext(ctx))
54})
55}
56
57// ContactContext gets contactID from parsed URI corresponding to trigger routes and set it to request context
58func ContactContext(next http.Handler) http.Handler {
59return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
60contactID := chi.URLParam(request, "contactId")
61if contactID == "" {
62render.Render(writer, request, api.ErrorInvalidRequest(fmt.Errorf("ContactID must be set")))
63return
64}
65ctx := context.WithValue(request.Context(), contactIDKey, contactID)
66next.ServeHTTP(writer, request.WithContext(ctx))
67})
68}
69
70// TagContext gets tagName from parsed URI corresponding to tag routes and set it to request context
71func TagContext(next http.Handler) http.Handler {
72return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
73tag := chi.URLParam(request, "tag")
74if tag == "" {
75render.Render(writer, request, api.ErrorInvalidRequest(fmt.Errorf("Tag must be set")))
76return
77}
78ctx := context.WithValue(request.Context(), tagKey, tag)
79next.ServeHTTP(writer, request.WithContext(ctx))
80})
81}
82
83// SubscriptionContext gets subscriptionId from parsed URI corresponding to subscription routes and set it to request context
84func SubscriptionContext(next http.Handler) http.Handler {
85return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
86triggerID := chi.URLParam(request, "subscriptionId")
87if triggerID == "" {
88render.Render(writer, request, api.ErrorInvalidRequest(fmt.Errorf("SubscriptionId must be set")))
89return
90}
91ctx := context.WithValue(request.Context(), subscriptionIDKey, triggerID)
92next.ServeHTTP(writer, request.WithContext(ctx))
93})
94}
95
96// Paginate gets page and size values from URI query and set it to request context. If query has not values sets given values
97func Paginate(defaultPage, defaultSize int64) func(next http.Handler) http.Handler {
98return func(next http.Handler) http.Handler {
99return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
100page, err := strconv.ParseInt(request.URL.Query().Get("p"), 10, 64)
101if err != nil {
102page = defaultPage
103}
104size, err := strconv.ParseInt(request.URL.Query().Get("size"), 10, 64)
105if err != nil {
106size = defaultSize
107}
108
109ctxPage := context.WithValue(request.Context(), pageKey, page)
110ctxSize := context.WithValue(ctxPage, sizeKey, size)
111next.ServeHTTP(writer, request.WithContext(ctxSize))
112})
113}
114}
115
116// DateRange gets from and to values from URI query and set it to request context. If query has not values sets given values
117func DateRange(defaultFrom, defaultTo string) func(next http.Handler) http.Handler {
118return func(next http.Handler) http.Handler {
119return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
120from := request.URL.Query().Get("from")
121if from == "" {
122from = defaultFrom
123}
124to := request.URL.Query().Get("to")
125if to == "" {
126to = defaultTo
127}
128
129ctxPage := context.WithValue(request.Context(), fromKey, from)
130ctxSize := context.WithValue(ctxPage, toKey, to)
131next.ServeHTTP(writer, request.WithContext(ctxSize))
132})
133}
134}
135