podman
352 строки · 9.4 Кб
1package handlers
2
import (
	"fmt"
	"net/http"
	"strconv"
	"strings"
)
8
// CORSOption represents a functional option for configuring the CORS middleware.
// Options are applied in the order they are passed to CORS. An option may
// report a problem by returning an error; the middleware currently discards
// such errors.
type CORSOption func(*cors) error

// cors is the http.Handler returned by the CORS middleware. It wraps the next
// handler h and carries the policy assembled from the CORSOption values.
type cors struct {
	// h is the wrapped handler. Non-OPTIONS requests are always forwarded to
	// it; OPTIONS requests only when ignoreOptions is set.
	h http.Handler

	// allowedHeaders lists request headers accepted in preflight requests.
	allowedHeaders []string
	// allowedMethods lists methods accepted in Access-Control-Request-Method.
	allowedMethods []string
	// allowedOrigins lists allowed origins. An empty list allows any origin,
	// and so does a "*" entry.
	allowedOrigins []string
	// allowedOriginValidator, when non-nil, alone decides whether an origin
	// is allowed.
	allowedOriginValidator OriginValidator
	// exposedHeaders is sent as Access-Control-Expose-Headers on CORS
	// responses to non-OPTIONS requests.
	exposedHeaders []string

	maxAge           int  // Access-Control-Max-Age in seconds; sent only when > 0
	optionStatusCode int  // status written for successful preflight responses
	ignoreOptions    bool // forward OPTIONS requests to h instead of answering them
	allowCredentials bool // send Access-Control-Allow-Credentials: true
}

// OriginValidator takes a request origin and reports whether that origin is allowed.
type OriginValidator func(string) bool
27
var (
	// defaultCorsOptionStatusCode is the status written for successful
	// preflight responses unless OptionStatusCode overrides it.
	defaultCorsOptionStatusCode = http.StatusOK

	// defaultCorsMethods are the methods allowed when AllowedMethods is not
	// used. A requested method from this list is never echoed back in
	// Access-Control-Allow-Methods.
	defaultCorsMethods = []string{http.MethodGet, http.MethodHead, http.MethodPost}

	// defaultCorsHeaders are request headers that are always allowed and are
	// never echoed back in Access-Control-Allow-Headers. Origin is included
	// because WebKit/Safari v9 sends the Origin header by default in AJAX
	// requests.
	defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
)

// Header names and special values used by the CORS middleware.
const (
	corsOptionMethod           string = http.MethodOptions
	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
	corsMaxAgeHeader           string = "Access-Control-Max-Age"
	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
	corsRequestMethodHeader    string = "Access-Control-Request-Method"
	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
	corsOriginHeader           string = "Origin"
	corsVaryHeader             string = "Vary"
	corsOriginMatchAll         string = "*"
)
49
50func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
51origin := r.Header.Get(corsOriginHeader)
52if !ch.isOriginAllowed(origin) {
53if r.Method != corsOptionMethod || ch.ignoreOptions {
54ch.h.ServeHTTP(w, r)
55}
56
57return
58}
59
60if r.Method == corsOptionMethod {
61if ch.ignoreOptions {
62ch.h.ServeHTTP(w, r)
63return
64}
65
66if _, ok := r.Header[corsRequestMethodHeader]; !ok {
67w.WriteHeader(http.StatusBadRequest)
68return
69}
70
71method := r.Header.Get(corsRequestMethodHeader)
72if !ch.isMatch(method, ch.allowedMethods) {
73w.WriteHeader(http.StatusMethodNotAllowed)
74return
75}
76
77requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
78allowedHeaders := []string{}
79for _, v := range requestHeaders {
80canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
81if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
82continue
83}
84
85if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
86w.WriteHeader(http.StatusForbidden)
87return
88}
89
90allowedHeaders = append(allowedHeaders, canonicalHeader)
91}
92
93if len(allowedHeaders) > 0 {
94w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
95}
96
97if ch.maxAge > 0 {
98w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
99}
100
101if !ch.isMatch(method, defaultCorsMethods) {
102w.Header().Set(corsAllowMethodsHeader, method)
103}
104} else if len(ch.exposedHeaders) > 0 {
105w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
106}
107
108if ch.allowCredentials {
109w.Header().Set(corsAllowCredentialsHeader, "true")
110}
111
112if len(ch.allowedOrigins) > 1 {
113w.Header().Set(corsVaryHeader, corsOriginHeader)
114}
115
116returnOrigin := origin
117if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
118returnOrigin = "*"
119} else {
120for _, o := range ch.allowedOrigins {
121// A configuration of * is different than explicitly setting an allowed
122// origin. Returning arbitrary origin headers in an access control allow
123// origin header is unsafe and is not required by any use case.
124if o == corsOriginMatchAll {
125returnOrigin = "*"
126break
127}
128}
129}
130w.Header().Set(corsAllowOriginHeader, returnOrigin)
131
132if r.Method == corsOptionMethod {
133w.WriteHeader(ch.optionStatusCode)
134return
135}
136ch.h.ServeHTTP(w, r)
137}
138
139// CORS provides Cross-Origin Resource Sharing middleware.
140// Example:
141//
142// import (
143// "net/http"
144//
145// "github.com/gorilla/handlers"
146// "github.com/gorilla/mux"
147// )
148//
149// func main() {
150// r := mux.NewRouter()
151// r.HandleFunc("/users", UserEndpoint)
152// r.HandleFunc("/projects", ProjectEndpoint)
153//
154// // Apply the CORS middleware to our top-level router, with the defaults.
155// http.ListenAndServe(":8000", handlers.CORS()(r))
156// }
157func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
158return func(h http.Handler) http.Handler {
159ch := parseCORSOptions(opts...)
160ch.h = h
161return ch
162}
163}
164
165func parseCORSOptions(opts ...CORSOption) *cors {
166ch := &cors{
167allowedMethods: defaultCorsMethods,
168allowedHeaders: defaultCorsHeaders,
169allowedOrigins: []string{},
170optionStatusCode: defaultCorsOptionStatusCode,
171}
172
173for _, option := range opts {
174_ = option(ch) //TODO: @bharat-rajani, return error to caller if not nil?
175}
176
177return ch
178}
179
180//
181// Functional options for configuring CORS.
182//
183
184// AllowedHeaders adds the provided headers to the list of allowed headers in a
185// CORS request.
186// This is an append operation so the headers Accept, Accept-Language,
187// and Content-Language are always allowed.
188// Content-Type must be explicitly declared if accepting Content-Types other than
189// application/x-www-form-urlencoded, multipart/form-data, or text/plain.
190func AllowedHeaders(headers []string) CORSOption {
191return func(ch *cors) error {
192for _, v := range headers {
193normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
194if normalizedHeader == "" {
195continue
196}
197
198if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
199ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
200}
201}
202
203return nil
204}
205}
206
207// AllowedMethods can be used to explicitly allow methods in the
208// Access-Control-Allow-Methods header.
209// This is a replacement operation so you must also
210// pass GET, HEAD, and POST if you wish to support those methods.
211func AllowedMethods(methods []string) CORSOption {
212return func(ch *cors) error {
213ch.allowedMethods = []string{}
214for _, v := range methods {
215normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
216if normalizedMethod == "" {
217continue
218}
219
220if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
221ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
222}
223}
224
225return nil
226}
227}
228
229// AllowedOrigins sets the allowed origins for CORS requests, as used in the
230// 'Allow-Access-Control-Origin' HTTP header.
231// Note: Passing in a []string{"*"} will allow any domain.
232func AllowedOrigins(origins []string) CORSOption {
233return func(ch *cors) error {
234for _, v := range origins {
235if v == corsOriginMatchAll {
236ch.allowedOrigins = []string{corsOriginMatchAll}
237return nil
238}
239}
240
241ch.allowedOrigins = origins
242return nil
243}
244}
245
246// AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
247// 'Allow-Access-Control-Origin' HTTP header.
248func AllowedOriginValidator(fn OriginValidator) CORSOption {
249return func(ch *cors) error {
250ch.allowedOriginValidator = fn
251return nil
252}
253}
254
255// OptionStatusCode sets a custom status code on the OPTIONS requests.
256// Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
257// and can be used if you need a custom status code (i.e 204).
258//
259// More informations on the spec:
260// https://fetch.spec.whatwg.org/#cors-preflight-fetch
261func OptionStatusCode(code int) CORSOption {
262return func(ch *cors) error {
263ch.optionStatusCode = code
264return nil
265}
266}
267
268// ExposedHeaders can be used to specify headers that are available
269// and will not be stripped out by the user-agent.
270func ExposedHeaders(headers []string) CORSOption {
271return func(ch *cors) error {
272ch.exposedHeaders = []string{}
273for _, v := range headers {
274normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
275if normalizedHeader == "" {
276continue
277}
278
279if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
280ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
281}
282}
283
284return nil
285}
286}
287
288// MaxAge determines the maximum age (in seconds) between preflight requests. A
289// maximum of 10 minutes is allowed. An age above this value will default to 10
290// minutes.
291func MaxAge(age int) CORSOption {
292return func(ch *cors) error {
293// Maximum of 10 minutes.
294if age > 600 {
295age = 600
296}
297
298ch.maxAge = age
299return nil
300}
301}
302
303// IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
304// passing them through to the next handler. This is useful when your application
305// or framework has a pre-existing mechanism for responding to OPTIONS requests.
306func IgnoreOptions() CORSOption {
307return func(ch *cors) error {
308ch.ignoreOptions = true
309return nil
310}
311}
312
313// AllowCredentials can be used to specify that the user agent may pass
314// authentication details along with the request.
315func AllowCredentials() CORSOption {
316return func(ch *cors) error {
317ch.allowCredentials = true
318return nil
319}
320}
321
322func (ch *cors) isOriginAllowed(origin string) bool {
323if origin == "" {
324return false
325}
326
327if ch.allowedOriginValidator != nil {
328return ch.allowedOriginValidator(origin)
329}
330
331if len(ch.allowedOrigins) == 0 {
332return true
333}
334
335for _, allowedOrigin := range ch.allowedOrigins {
336if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
337return true
338}
339}
340
341return false
342}
343
344func (ch *cors) isMatch(needle string, haystack []string) bool {
345for _, v := range haystack {
346if v == needle {
347return true
348}
349}
350
351return false
352}
353