kraken
157 строк · 4.1 Кб
// Copyright (c) 2016-2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14package middleware
15
16import (
17"fmt"
18"io"
19"net/http"
20"testing"
21"time"
22
23"github.com/uber/kraken/utils/httputil"
24"github.com/uber/kraken/utils/testutil"
25
26"github.com/go-chi/chi"
27"github.com/stretchr/testify/require"
28"github.com/uber-go/tally"
29)
30
31func TestScopeByEndpoint(t *testing.T) {
32tests := []struct {
33method string
34path string
35reqPath string
36expectedEndpoint string
37}{
38{"GET", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"},
39{"POST", "/foo/{foo}/bar/{bar}", "/foo/x/bar/y", "foo.bar"},
40{"GET", "/a/b/c", "/a/b/c", "a.b.c"},
41{"GET", "/", "/", ""},
42{"GET", "/x/{a}/{b}/{c}", "/x/a/b/c", "x"},
43}
44
45for _, test := range tests {
46t.Run(test.method+" "+test.path, func(t *testing.T) {
47require := require.New(t)
48
49stats := tally.NewTestScope("", nil)
50
51r := chi.NewRouter()
52r.HandleFunc(test.path, func(w http.ResponseWriter, r *http.Request) {
53tagEndpoint(stats, r).Counter("count").Inc(1)
54})
55addr, stop := testutil.StartServer(r)
56defer stop()
57
58_, err := httputil.Send(test.method, fmt.Sprintf("http://%s%s", addr, test.reqPath))
59require.NoError(err)
60
61require.Equal(1, len(stats.Snapshot().Counters()))
62for _, v := range stats.Snapshot().Counters() {
63require.Equal("count", v.Name())
64require.Equal(int64(1), v.Value())
65require.Equal(map[string]string{
66"endpoint": test.expectedEndpoint,
67"method": test.method,
68}, v.Tags())
69}
70})
71}
72}
73
74func TestLatencyTimer(t *testing.T) {
75require := require.New(t)
76
77stats := tally.NewTestScope("", nil)
78
79r := chi.NewRouter()
80r.Use(LatencyTimer(stats))
81r.Get("/foo/{foo}", func(w http.ResponseWriter, r *http.Request) {
82time.Sleep(200 * time.Millisecond)
83})
84
85addr, stop := testutil.StartServer(r)
86defer stop()
87
88_, err := httputil.Get(fmt.Sprintf("http://%s/foo/x", addr))
89require.NoError(err)
90
91now := time.Now()
92
93require.Equal(1, len(stats.Snapshot().Timers()))
94for _, v := range stats.Snapshot().Timers() {
95require.Equal("latency", v.Name())
96require.WithinDuration(now, now.Add(v.Values()[0]), 500*time.Millisecond)
97require.Equal(map[string]string{
98"endpoint": "foo",
99"method": "GET",
100}, v.Tags())
101}
102}
103
104func TestStatusCounter(t *testing.T) {
105tests := []struct {
106desc string
107handler func(http.ResponseWriter, *http.Request)
108expectedStatus string
109}{
110{
111"empty handler counts 200",
112func(http.ResponseWriter, *http.Request) {},
113"200",
114}, {
115"writes count 200",
116func(w http.ResponseWriter, _ *http.Request) { io.WriteString(w, "OK") },
117"200",
118}, {
119"write header",
120func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(500) },
121"500",
122}, {
123"multiple write header calls only measures first call",
124func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(400); w.WriteHeader(500) },
125"400",
126},
127}
128for _, test := range tests {
129t.Run(test.desc, func(t *testing.T) {
130require := require.New(t)
131
132stats := tally.NewTestScope("", nil)
133
134r := chi.NewRouter()
135r.Use(StatusCounter(stats))
136r.Get("/foo/{foo}", test.handler)
137
138addr, stop := testutil.StartServer(r)
139defer stop()
140
141for i := 0; i < 5; i++ {
142_, err := http.Get(fmt.Sprintf("http://%s/foo/x", addr))
143require.NoError(err)
144}
145
146require.Equal(1, len(stats.Snapshot().Counters()))
147for _, v := range stats.Snapshot().Counters() {
148require.Equal(test.expectedStatus, v.Name())
149require.Equal(int64(5), v.Value())
150require.Equal(map[string]string{
151"endpoint": "foo",
152"method": "GET",
153}, v.Tags())
154}
155})
156}
157}
158