netramesh

Форк
0
/
clientserver_test.go 
1481 строка · 39.9 Кб
1
// Copyright 2015 The Go Authors. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
4

5
// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
6

7
package http_test
8

9
import (
10
	"bytes"
11
	"compress/gzip"
12
	"crypto/tls"
13
	"fmt"
14
	"io"
15
	"io/ioutil"
16
	"log"
17
	"net"
18
	. "net/http"
19
	"net/http/httptest"
20
	"net/http/httputil"
21
	"net/url"
22
	"os"
23
	"reflect"
24
	"runtime"
25
	"sort"
26
	"strings"
27
	"sync"
28
	"sync/atomic"
29
	"testing"
30
	"time"
31
)
32

33
// clientServerTest bundles an httptest server with the client and transport
// used to talk to it.
type clientServerTest struct {
	t  *testing.T       // test that owns the fixture; fatal errors are reported here
	h2 bool             // whether the pair was set up in HTTP/2 mode
	h  Handler          // handler the server was created with
	ts *httptest.Server // server side
	tr *Transport       // transport backing c
	c  *Client          // client side
}

// close drops the transport's idle connections and then shuts the server down.
func (t *clientServerTest) close() {
	t.tr.CloseIdleConnections()
	t.ts.Close()
}

// getURL performs a GET of u with the fixture's client and returns the full
// response body as a string. A request or read error fails the test.
func (t *clientServerTest) getURL(u string) string {
	resp, err := t.c.Get(u)
	if err != nil {
		t.t.Fatal(err)
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.t.Fatal(err)
	}
	return string(body)
}

// scheme returns "https" for an HTTP/2 fixture and "http" otherwise.
func (t *clientServerTest) scheme() string {
	if !t.h2 {
		return "http"
	}
	return "https"
}
66

67
// Named values for the h2 argument taken by newClientServerTest and the
// per-protocol test helpers in this file.
const (
	h1Mode = false // run the test over HTTP/1
	h2Mode = true  // run the test over HTTP/2
)
71

72
var optQuietLog = func(ts *httptest.Server) {
73
	ts.Config.ErrorLog = quietLog
74
}
75

76
func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
77
	cst := &clientServerTest{
78
		t:  t,
79
		h2: h2,
80
		h:  h,
81
		tr: &Transport{},
82
	}
83
	cst.c = &Client{Transport: cst.tr}
84
	cst.ts = httptest.NewUnstartedServer(h)
85

86
	for _, opt := range opts {
87
		switch opt := opt.(type) {
88
		case func(*Transport):
89
			opt(cst.tr)
90
		case func(*httptest.Server):
91
			opt(cst.ts)
92
		default:
93
			t.Fatalf("unhandled option type %T", opt)
94
		}
95
	}
96

97
	if !h2 {
98
		cst.ts.Start()
99
		return cst
100
	}
101
	ExportHttp2ConfigureServer(cst.ts.Config, nil)
102
	cst.ts.TLS = cst.ts.Config.TLSConfig
103
	cst.ts.StartTLS()
104

105
	cst.tr.TLSClientConfig = &tls.Config{
106
		InsecureSkipVerify: true,
107
	}
108
	if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
109
		t.Fatal(err)
110
	}
111
	return cst
112
}
113

114
// Testing the newClientServerTest helper itself.
115
func TestNewClientServerTest(t *testing.T) {
116
	var got struct {
117
		sync.Mutex
118
		log []string
119
	}
120
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
121
		got.Lock()
122
		defer got.Unlock()
123
		got.log = append(got.log, r.Proto)
124
	})
125
	for _, v := range [2]bool{false, true} {
126
		cst := newClientServerTest(t, v, h)
127
		if _, err := cst.c.Head(cst.ts.URL); err != nil {
128
			t.Fatal(err)
129
		}
130
		cst.close()
131
	}
132
	got.Lock() // no need to unlock
133
	if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
134
		t.Errorf("got %q; want %q", got.log, want)
135
	}
136
}
137

138
func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) }
139
func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) }
140

141
func testChunkedResponseHeaders(t *testing.T, h2 bool) {
142
	defer afterTest(t)
143
	log.SetOutput(ioutil.Discard) // is noisy otherwise
144
	defer log.SetOutput(os.Stderr)
145
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
146
		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
147
		w.(Flusher).Flush()
148
		fmt.Fprintf(w, "I am a chunked response.")
149
	}))
150
	defer cst.close()
151

152
	res, err := cst.c.Get(cst.ts.URL)
153
	if err != nil {
154
		t.Fatalf("Get error: %v", err)
155
	}
156
	defer res.Body.Close()
157
	if g, e := res.ContentLength, int64(-1); g != e {
158
		t.Errorf("expected ContentLength of %d; got %d", e, g)
159
	}
160
	wantTE := []string{"chunked"}
161
	if h2 {
162
		wantTE = nil
163
	}
164
	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
165
		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
166
	}
167
	if got, haveCL := res.Header["Content-Length"]; haveCL {
168
		t.Errorf("Unexpected Content-Length: %q", got)
169
	}
170
}
171

172
// reqFunc sends one request to url using c and returns the response.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare describes a test that runs the same handler and request over
// HTTP/1 and HTTP/2 and compares the two results against each other.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // required; served by both servers
	ReqFunc            reqFunc                           // optional; how to send the request
	CheckResponse      func(proto string, res *Response) // optional; called after normalization
	EarlyCheckResponse func(proto string, res *Response) // optional; called before normalization
	Opts               []interface{}                     // forwarded to newClientServerTest
}

// reqFunc returns tt.ReqFunc, or (*Client).Get when no ReqFunc was set.
func (tt h12Compare) reqFunc() reqFunc {
	if fn := tt.ReqFunc; fn != nil {
		return fn
	}
	return (*Client).Get
}
190

191
func (tt h12Compare) run(t *testing.T) {
192
	setParallel(t)
193
	cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
194
	defer cst1.close()
195
	cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
196
	defer cst2.close()
197

198
	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
199
	if err != nil {
200
		t.Errorf("HTTP/1 request: %v", err)
201
		return
202
	}
203
	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
204
	if err != nil {
205
		t.Errorf("HTTP/2 request: %v", err)
206
		return
207
	}
208

209
	if fn := tt.EarlyCheckResponse; fn != nil {
210
		fn("HTTP/1.1", res1)
211
		fn("HTTP/2.0", res2)
212
	}
213

214
	tt.normalizeRes(t, res1, "HTTP/1.1")
215
	tt.normalizeRes(t, res2, "HTTP/2.0")
216
	res1body, res2body := res1.Body, res2.Body
217

218
	eres1 := mostlyCopy(res1)
219
	eres2 := mostlyCopy(res2)
220
	if !reflect.DeepEqual(eres1, eres2) {
221
		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
222
			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
223
	}
224
	if !reflect.DeepEqual(res1body, res2body) {
225
		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
226
	}
227
	if fn := tt.CheckResponse; fn != nil {
228
		res1.Body, res2.Body = res1body, res2body
229
		fn("HTTP/1.1", res1)
230
		fn("HTTP/2.0", res2)
231
	}
232
}
233

234
// mostlyCopy returns a shallow copy of r with the Body, TransferEncoding,
// TLS and Request fields set to nil. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.TransferEncoding, dup.TLS, dup.Request = nil, nil, nil, nil
	return dup
}
242

243
// slurpResult is a response body that has already been read in full. It
// keeps the bytes that were read and the error the read ended with, while
// still satisfying io.ReadCloser through the embedded field.
type slurpResult struct {
	io.ReadCloser
	body []byte // everything read from the original body
	err  error  // error returned by that read, if any
}

// String formats the captured body (quoted) and read error.
func (sr slurpResult) String() string {
	const format = "body %q; err %v"
	return fmt.Sprintf(format, sr.body, sr.err)
}
250

251
func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
252
	if res.Proto == wantProto {
253
		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
254
	} else {
255
		t.Errorf("got %q response; want %q", res.Proto, wantProto)
256
	}
257
	slurp, err := ioutil.ReadAll(res.Body)
258

259
	res.Body.Close()
260
	res.Body = slurpResult{
261
		ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)),
262
		body:       slurp,
263
		err:        err,
264
	}
265
	for i, v := range res.Header["Date"] {
266
		res.Header["Date"][i] = strings.Repeat("x", len(v))
267
	}
268
	if res.Request == nil {
269
		t.Errorf("for %s, no request", wantProto)
270
	}
271
	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
272
		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
273
	}
274
}
275

276
// Issue 13532
277
func TestH12_HeadContentLengthNoBody(t *testing.T) {
278
	h12Compare{
279
		ReqFunc: (*Client).Head,
280
		Handler: func(w ResponseWriter, r *Request) {
281
		},
282
	}.run(t)
283
}
284

285
func TestH12_HeadContentLengthSmallBody(t *testing.T) {
286
	h12Compare{
287
		ReqFunc: (*Client).Head,
288
		Handler: func(w ResponseWriter, r *Request) {
289
			io.WriteString(w, "small")
290
		},
291
	}.run(t)
292
}
293

294
func TestH12_HeadContentLengthLargeBody(t *testing.T) {
295
	h12Compare{
296
		ReqFunc: (*Client).Head,
297
		Handler: func(w ResponseWriter, r *Request) {
298
			chunk := strings.Repeat("x", 512<<10)
299
			for i := 0; i < 10; i++ {
300
				io.WriteString(w, chunk)
301
			}
302
		},
303
	}.run(t)
304
}
305

306
func TestH12_200NoBody(t *testing.T) {
307
	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
308
}
309

310
func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
311
func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
312
func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
313

314
func testH12_noBody(t *testing.T, status int) {
315
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
316
		w.WriteHeader(status)
317
	}}.run(t)
318
}
319

320
func TestH12_SmallBody(t *testing.T) {
321
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
322
		io.WriteString(w, "small body")
323
	}}.run(t)
324
}
325

326
func TestH12_ExplicitContentLength(t *testing.T) {
327
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
328
		w.Header().Set("Content-Length", "3")
329
		io.WriteString(w, "foo")
330
	}}.run(t)
331
}
332

333
func TestH12_FlushBeforeBody(t *testing.T) {
334
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
335
		w.(Flusher).Flush()
336
		io.WriteString(w, "foo")
337
	}}.run(t)
338
}
339

340
func TestH12_FlushMidBody(t *testing.T) {
341
	h12Compare{Handler: func(w ResponseWriter, r *Request) {
342
		io.WriteString(w, "foo")
343
		w.(Flusher).Flush()
344
		io.WriteString(w, "bar")
345
	}}.run(t)
346
}
347

348
func TestH12_Head_ExplicitLen(t *testing.T) {
349
	h12Compare{
350
		ReqFunc: (*Client).Head,
351
		Handler: func(w ResponseWriter, r *Request) {
352
			if r.Method != "HEAD" {
353
				t.Errorf("unexpected method %q", r.Method)
354
			}
355
			w.Header().Set("Content-Length", "1235")
356
		},
357
	}.run(t)
358
}
359

360
func TestH12_Head_ImplicitLen(t *testing.T) {
361
	h12Compare{
362
		ReqFunc: (*Client).Head,
363
		Handler: func(w ResponseWriter, r *Request) {
364
			if r.Method != "HEAD" {
365
				t.Errorf("unexpected method %q", r.Method)
366
			}
367
			io.WriteString(w, "foo")
368
		},
369
	}.run(t)
370
}
371

372
func TestH12_HandlerWritesTooLittle(t *testing.T) {
373
	h12Compare{
374
		Handler: func(w ResponseWriter, r *Request) {
375
			w.Header().Set("Content-Length", "3")
376
			io.WriteString(w, "12") // one byte short
377
		},
378
		CheckResponse: func(proto string, res *Response) {
379
			sr, ok := res.Body.(slurpResult)
380
			if !ok {
381
				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
382
				return
383
			}
384
			if sr.err != io.ErrUnexpectedEOF {
385
				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
386
			}
387
			if string(sr.body) != "12" {
388
				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
389
			}
390
		},
391
	}.run(t)
392
}
393

394
// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
395
// writing more than they declared. This test does not test whether
396
// the transport deals with too much data, though, since the server
397
// doesn't make it possible to send bogus data. For those tests, see
398
// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
399
// (for HTTP/2).
400
func TestH12_HandlerWritesTooMuch(t *testing.T) {
401
	h12Compare{
402
		Handler: func(w ResponseWriter, r *Request) {
403
			w.Header().Set("Content-Length", "3")
404
			w.(Flusher).Flush()
405
			io.WriteString(w, "123")
406
			w.(Flusher).Flush()
407
			n, err := io.WriteString(w, "x") // too many
408
			if n > 0 || err == nil {
409
				t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
410
			}
411
		},
412
	}.run(t)
413
}
414

415
// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
416
// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
417
func TestH12_AutoGzip(t *testing.T) {
418
	h12Compare{
419
		Handler: func(w ResponseWriter, r *Request) {
420
			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
421
				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
422
			}
423
			w.Header().Set("Content-Encoding", "gzip")
424
			gz := gzip.NewWriter(w)
425
			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
426
			gz.Close()
427
		},
428
	}.run(t)
429
}
430

431
func TestH12_AutoGzip_Disabled(t *testing.T) {
432
	h12Compare{
433
		Opts: []interface{}{
434
			func(tr *Transport) { tr.DisableCompression = true },
435
		},
436
		Handler: func(w ResponseWriter, r *Request) {
437
			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
438
			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
439
				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
440
			}
441
		},
442
	}.run(t)
443
}
444

445
// Test304Responses verifies that 304s don't declare that they're
446
// chunking in their response headers and aren't allowed to produce
447
// output.
448
func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) }
449
func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) }
450

451
func test304Responses(t *testing.T, h2 bool) {
452
	defer afterTest(t)
453
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
454
		w.WriteHeader(StatusNotModified)
455
		_, err := w.Write([]byte("illegal body"))
456
		if err != ErrBodyNotAllowed {
457
			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
458
		}
459
	}))
460
	defer cst.close()
461
	res, err := cst.c.Get(cst.ts.URL)
462
	if err != nil {
463
		t.Fatal(err)
464
	}
465
	if len(res.TransferEncoding) > 0 {
466
		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
467
	}
468
	body, err := ioutil.ReadAll(res.Body)
469
	if err != nil {
470
		t.Error(err)
471
	}
472
	if len(body) > 0 {
473
		t.Errorf("got unexpected body %q", string(body))
474
	}
475
}
476

477
func TestH12_ServerEmptyContentLength(t *testing.T) {
478
	h12Compare{
479
		Handler: func(w ResponseWriter, r *Request) {
480
			w.Header()["Content-Type"] = []string{""}
481
			io.WriteString(w, "<html><body>hi</body></html>")
482
		},
483
	}.run(t)
484
}
485

486
func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
487
	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
488
}
489

490
func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
491
	h12requestContentLength(t, func() io.Reader { return nil }, 0)
492
}
493

494
func TestH12_RequestContentLength_Unknown(t *testing.T) {
495
	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
496
}
497

498
func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
499
	h12Compare{
500
		Handler: func(w ResponseWriter, r *Request) {
501
			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
502
			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
503
		},
504
		ReqFunc: func(c *Client, url string) (*Response, error) {
505
			return c.Post(url, "text/plain", bodyfn())
506
		},
507
		CheckResponse: func(proto string, res *Response) {
508
			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
509
				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
510
			}
511
		},
512
	}.run(t)
513
}
514

515
// Tests that closing the Request.Cancel channel also while still
516
// reading the response body. Issue 13159.
517
func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
518
func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) }
519
func testCancelRequestMidBody(t *testing.T, h2 bool) {
520
	defer afterTest(t)
521
	unblock := make(chan bool)
522
	didFlush := make(chan bool, 1)
523
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
524
		io.WriteString(w, "Hello")
525
		w.(Flusher).Flush()
526
		didFlush <- true
527
		<-unblock
528
		io.WriteString(w, ", world.")
529
	}))
530
	defer cst.close()
531
	defer close(unblock)
532

533
	req, _ := NewRequest("GET", cst.ts.URL, nil)
534
	cancel := make(chan struct{})
535
	req.Cancel = cancel
536

537
	res, err := cst.c.Do(req)
538
	if err != nil {
539
		t.Fatal(err)
540
	}
541
	defer res.Body.Close()
542
	<-didFlush
543

544
	// Read a bit before we cancel. (Issue 13626)
545
	// We should have "Hello" at least sitting there.
546
	firstRead := make([]byte, 10)
547
	n, err := res.Body.Read(firstRead)
548
	if err != nil {
549
		t.Fatal(err)
550
	}
551
	firstRead = firstRead[:n]
552

553
	close(cancel)
554

555
	rest, err := ioutil.ReadAll(res.Body)
556
	all := string(firstRead) + string(rest)
557
	if all != "Hello" {
558
		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
559
	}
560
	if !reflect.DeepEqual(err, ExportErrRequestCanceled) {
561
		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
562
	}
563
}
564

565
// Tests that clients can send trailers to a server and that the server can read them.
566
func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) }
567
func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) }
568

569
func testTrailersClientToServer(t *testing.T, h2 bool) {
570
	defer afterTest(t)
571
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
572
		var decl []string
573
		for k := range r.Trailer {
574
			decl = append(decl, k)
575
		}
576
		sort.Strings(decl)
577

578
		slurp, err := ioutil.ReadAll(r.Body)
579
		if err != nil {
580
			t.Errorf("Server reading request body: %v", err)
581
		}
582
		if string(slurp) != "foo" {
583
			t.Errorf("Server read request body %q; want foo", slurp)
584
		}
585
		if r.Trailer == nil {
586
			io.WriteString(w, "nil Trailer")
587
		} else {
588
			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
589
				decl,
590
				r.Trailer.Get("Client-Trailer-A"),
591
				r.Trailer.Get("Client-Trailer-B"))
592
		}
593
	}))
594
	defer cst.close()
595

596
	var req *Request
597
	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
598
		eofReaderFunc(func() {
599
			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
600
		}),
601
		strings.NewReader("foo"),
602
		eofReaderFunc(func() {
603
			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
604
		}),
605
	))
606
	req.Trailer = Header{
607
		"Client-Trailer-A": nil, //  to be set later
608
		"Client-Trailer-B": nil, //  to be set later
609
	}
610
	req.ContentLength = -1
611
	res, err := cst.c.Do(req)
612
	if err != nil {
613
		t.Fatal(err)
614
	}
615
	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
616
		t.Error(err)
617
	}
618
}
619

620
// Tests that servers send trailers to a client and that the client can read them.
621
func TestTrailersServerToClient_h1(t *testing.T)       { testTrailersServerToClient(t, h1Mode, false) }
622
func TestTrailersServerToClient_h2(t *testing.T)       { testTrailersServerToClient(t, h2Mode, false) }
623
func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) }
624
func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) }
625

626
func testTrailersServerToClient(t *testing.T, h2, flush bool) {
627
	defer afterTest(t)
628
	const body = "Some body"
629
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
630
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
631
		w.Header().Add("Trailer", "Server-Trailer-C")
632

633
		io.WriteString(w, body)
634
		if flush {
635
			w.(Flusher).Flush()
636
		}
637

638
		// How handlers set Trailers: declare it ahead of time
639
		// with the Trailer header, and then mutate the
640
		// Header() of those values later, after the response
641
		// has been written (we wrote to w above).
642
		w.Header().Set("Server-Trailer-A", "valuea")
643
		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
644
		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
645
	}))
646
	defer cst.close()
647

648
	res, err := cst.c.Get(cst.ts.URL)
649
	if err != nil {
650
		t.Fatal(err)
651
	}
652

653
	wantHeader := Header{
654
		"Content-Type": {"text/plain; charset=utf-8"},
655
	}
656
	wantLen := -1
657
	if h2 && !flush {
658
		// In HTTP/1.1, any use of trailers forces HTTP/1.1
659
		// chunking and a flush at the first write. That's
660
		// unnecessary with HTTP/2's framing, so the server
661
		// is able to calculate the length while still sending
662
		// trailers afterwards.
663
		wantLen = len(body)
664
		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
665
	}
666
	if res.ContentLength != int64(wantLen) {
667
		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
668
	}
669

670
	delete(res.Header, "Date") // irrelevant for test
671
	if !reflect.DeepEqual(res.Header, wantHeader) {
672
		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
673
	}
674

675
	if got, want := res.Trailer, (Header{
676
		"Server-Trailer-A": nil,
677
		"Server-Trailer-B": nil,
678
		"Server-Trailer-C": nil,
679
	}); !reflect.DeepEqual(got, want) {
680
		t.Errorf("Trailer before body read = %v; want %v", got, want)
681
	}
682

683
	if err := wantBody(res, nil, body); err != nil {
684
		t.Fatal(err)
685
	}
686

687
	if got, want := res.Trailer, (Header{
688
		"Server-Trailer-A": {"valuea"},
689
		"Server-Trailer-B": nil,
690
		"Server-Trailer-C": {"valuec"},
691
	}); !reflect.DeepEqual(got, want) {
692
		t.Errorf("Trailer after body read = %v; want %v", got, want)
693
	}
694
}
695

696
// Don't allow a Body.Read after Body.Close. Issue 13648.
697
func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) }
698
func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) }
699

700
func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
701
	defer afterTest(t)
702
	const body = "Some body"
703
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
704
		io.WriteString(w, body)
705
	}))
706
	defer cst.close()
707
	res, err := cst.c.Get(cst.ts.URL)
708
	if err != nil {
709
		t.Fatal(err)
710
	}
711
	res.Body.Close()
712
	data, err := ioutil.ReadAll(res.Body)
713
	if len(data) != 0 || err == nil {
714
		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
715
	}
716
}
717

718
func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) }
719
func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) }
720
func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
721
	defer afterTest(t)
722
	const reqBody = "some request body"
723
	const resBody = "some response body"
724
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
725
		var wg sync.WaitGroup
726
		wg.Add(2)
727
		didRead := make(chan bool, 1)
728
		// Read in one goroutine.
729
		go func() {
730
			defer wg.Done()
731
			data, err := ioutil.ReadAll(r.Body)
732
			if string(data) != reqBody {
733
				t.Errorf("Handler read %q; want %q", data, reqBody)
734
			}
735
			if err != nil {
736
				t.Errorf("Handler Read: %v", err)
737
			}
738
			didRead <- true
739
		}()
740
		// Write in another goroutine.
741
		go func() {
742
			defer wg.Done()
743
			if !h2 {
744
				// our HTTP/1 implementation intentionally
745
				// doesn't permit writes during read (mostly
746
				// due to it being undefined); if that is ever
747
				// relaxed, change this.
748
				<-didRead
749
			}
750
			io.WriteString(w, resBody)
751
		}()
752
		wg.Wait()
753
	}))
754
	defer cst.close()
755
	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
756
	req.Header.Add("Expect", "100-continue") // just to complicate things
757
	res, err := cst.c.Do(req)
758
	if err != nil {
759
		t.Fatal(err)
760
	}
761
	data, err := ioutil.ReadAll(res.Body)
762
	defer res.Body.Close()
763
	if err != nil {
764
		t.Fatal(err)
765
	}
766
	if string(data) != resBody {
767
		t.Errorf("read %q; want %q", data, resBody)
768
	}
769
}
770

771
func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) }
772
func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) }
773
func testConnectRequest(t *testing.T, h2 bool) {
774
	defer afterTest(t)
775
	gotc := make(chan *Request, 1)
776
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
777
		gotc <- r
778
	}))
779
	defer cst.close()
780

781
	u, err := url.Parse(cst.ts.URL)
782
	if err != nil {
783
		t.Fatal(err)
784
	}
785

786
	tests := []struct {
787
		req  *Request
788
		want string
789
	}{
790
		{
791
			req: &Request{
792
				Method: "CONNECT",
793
				Header: Header{},
794
				URL:    u,
795
			},
796
			want: u.Host,
797
		},
798
		{
799
			req: &Request{
800
				Method: "CONNECT",
801
				Header: Header{},
802
				URL:    u,
803
				Host:   "example.com:123",
804
			},
805
			want: "example.com:123",
806
		},
807
	}
808

809
	for i, tt := range tests {
810
		res, err := cst.c.Do(tt.req)
811
		if err != nil {
812
			t.Errorf("%d. RoundTrip = %v", i, err)
813
			continue
814
		}
815
		res.Body.Close()
816
		req := <-gotc
817
		if req.Method != "CONNECT" {
818
			t.Errorf("method = %q; want CONNECT", req.Method)
819
		}
820
		if req.Host != tt.want {
821
			t.Errorf("Host = %q; want %q", req.Host, tt.want)
822
		}
823
		if req.URL.Host != tt.want {
824
			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
825
		}
826
	}
827
}
828

829
func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) }
830
func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) }
831
func testTransportUserAgent(t *testing.T, h2 bool) {
832
	defer afterTest(t)
833
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
834
		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
835
	}))
836
	defer cst.close()
837

838
	either := func(a, b string) string {
839
		if h2 {
840
			return b
841
		}
842
		return a
843
	}
844

845
	tests := []struct {
846
		setup func(*Request)
847
		want  string
848
	}{
849
		{
850
			func(r *Request) {},
851
			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
852
		},
853
		{
854
			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
855
			`["foo/1.2.3"]`,
856
		},
857
		{
858
			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
859
			`["single"]`,
860
		},
861
		{
862
			func(r *Request) { r.Header.Set("User-Agent", "") },
863
			`[]`,
864
		},
865
		{
866
			func(r *Request) { r.Header["User-Agent"] = nil },
867
			`[]`,
868
		},
869
	}
870
	for i, tt := range tests {
871
		req, _ := NewRequest("GET", cst.ts.URL, nil)
872
		tt.setup(req)
873
		res, err := cst.c.Do(req)
874
		if err != nil {
875
			t.Errorf("%d. RoundTrip = %v", i, err)
876
			continue
877
		}
878
		slurp, err := ioutil.ReadAll(res.Body)
879
		res.Body.Close()
880
		if err != nil {
881
			t.Errorf("%d. read body = %v", i, err)
882
			continue
883
		}
884
		if string(slurp) != tt.want {
885
			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
886
		}
887
	}
888
}
889

890
func TestStarRequestFoo_h1(t *testing.T)     { testStarRequest(t, "FOO", h1Mode) }
891
func TestStarRequestFoo_h2(t *testing.T)     { testStarRequest(t, "FOO", h2Mode) }
892
func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) }
893
func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) }
894
func testStarRequest(t *testing.T, method string, h2 bool) {
895
	defer afterTest(t)
896
	gotc := make(chan *Request, 1)
897
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
898
		w.Header().Set("foo", "bar")
899
		gotc <- r
900
		w.(Flusher).Flush()
901
	}))
902
	defer cst.close()
903

904
	u, err := url.Parse(cst.ts.URL)
905
	if err != nil {
906
		t.Fatal(err)
907
	}
908
	u.Path = "*"
909

910
	req := &Request{
911
		Method: method,
912
		Header: Header{},
913
		URL:    u,
914
	}
915

916
	res, err := cst.c.Do(req)
917
	if err != nil {
918
		t.Fatalf("RoundTrip = %v", err)
919
	}
920
	res.Body.Close()
921

922
	wantFoo := "bar"
923
	wantLen := int64(-1)
924
	if method == "OPTIONS" {
925
		wantFoo = ""
926
		wantLen = 0
927
	}
928
	if res.StatusCode != 200 {
929
		t.Errorf("status code = %v; want %d", res.Status, 200)
930
	}
931
	if res.ContentLength != wantLen {
932
		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
933
	}
934
	if got := res.Header.Get("foo"); got != wantFoo {
935
		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
936
	}
937
	select {
938
	case req = <-gotc:
939
	default:
940
		req = nil
941
	}
942
	if req == nil {
943
		if method != "OPTIONS" {
944
			t.Fatalf("handler never got request")
945
		}
946
		return
947
	}
948
	if req.Method != method {
949
		t.Errorf("method = %q; want %q", req.Method, method)
950
	}
951
	if req.URL.Path != "*" {
952
		t.Errorf("URL.Path = %q; want *", req.URL.Path)
953
	}
954
	if req.RequestURI != "*" {
955
		t.Errorf("RequestURI = %q; want *", req.RequestURI)
956
	}
957
}
958

959
// Issue 13957
960
func TestTransportDiscardsUnneededConns(t *testing.T) {
961
	setParallel(t)
962
	defer afterTest(t)
963
	cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
964
		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
965
	}))
966
	defer cst.close()
967

968
	var numOpen, numClose int32 // atomic
969

970
	tlsConfig := &tls.Config{InsecureSkipVerify: true}
971
	tr := &Transport{
972
		TLSClientConfig: tlsConfig,
973
		DialTLS: func(_, addr string) (net.Conn, error) {
974
			time.Sleep(10 * time.Millisecond)
975
			rc, err := net.Dial("tcp", addr)
976
			if err != nil {
977
				return nil, err
978
			}
979
			atomic.AddInt32(&numOpen, 1)
980
			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
981
			return tls.Client(c, tlsConfig), nil
982
		},
983
	}
984
	if err := ExportHttp2ConfigureTransport(tr); err != nil {
985
		t.Fatal(err)
986
	}
987
	defer tr.CloseIdleConnections()
988

989
	c := &Client{Transport: tr}
990

991
	const N = 10
992
	gotBody := make(chan string, N)
993
	var wg sync.WaitGroup
994
	for i := 0; i < N; i++ {
995
		wg.Add(1)
996
		go func() {
997
			defer wg.Done()
998
			resp, err := c.Get(cst.ts.URL)
999
			if err != nil {
1000
				t.Errorf("Get: %v", err)
1001
				return
1002
			}
1003
			defer resp.Body.Close()
1004
			slurp, err := ioutil.ReadAll(resp.Body)
1005
			if err != nil {
1006
				t.Error(err)
1007
			}
1008
			gotBody <- string(slurp)
1009
		}()
1010
	}
1011
	wg.Wait()
1012
	close(gotBody)
1013

1014
	var last string
1015
	for got := range gotBody {
1016
		if last == "" {
1017
			last = got
1018
			continue
1019
		}
1020
		if got != last {
1021
			t.Errorf("Response body changed: %q -> %q", last, got)
1022
		}
1023
	}
1024

1025
	var open, close int32
1026
	for i := 0; i < 150; i++ {
1027
		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1028
		if open < 1 {
1029
			t.Fatalf("open = %d; want at least", open)
1030
		}
1031
		if close == open-1 {
1032
			// Success
1033
			return
1034
		}
1035
		time.Sleep(10 * time.Millisecond)
1036
	}
1037
	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1038
}
1039

1040
// tests that Transport doesn't retain a pointer to the provided request.
1041
func TestTransportGCRequest_Body_h1(t *testing.T)   { testTransportGCRequest(t, h1Mode, true) }
1042
func TestTransportGCRequest_Body_h2(t *testing.T)   { testTransportGCRequest(t, h2Mode, true) }
1043
func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) }
1044
func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) }
1045
func testTransportGCRequest(t *testing.T, h2, body bool) {
1046
	setParallel(t)
1047
	defer afterTest(t)
1048
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1049
		ioutil.ReadAll(r.Body)
1050
		if body {
1051
			io.WriteString(w, "Hello.")
1052
		}
1053
	}))
1054
	defer cst.close()
1055

1056
	didGC := make(chan struct{})
1057
	(func() {
1058
		body := strings.NewReader("some body")
1059
		req, _ := NewRequest("POST", cst.ts.URL, body)
1060
		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1061
		res, err := cst.c.Do(req)
1062
		if err != nil {
1063
			t.Fatal(err)
1064
		}
1065
		if _, err := ioutil.ReadAll(res.Body); err != nil {
1066
			t.Fatal(err)
1067
		}
1068
		if err := res.Body.Close(); err != nil {
1069
			t.Fatal(err)
1070
		}
1071
	})()
1072
	timeout := time.NewTimer(5 * time.Second)
1073
	defer timeout.Stop()
1074
	for {
1075
		select {
1076
		case <-didGC:
1077
			return
1078
		case <-time.After(100 * time.Millisecond):
1079
			runtime.GC()
1080
		case <-timeout.C:
1081
			t.Fatal("never saw GC of request")
1082
		}
1083
	}
1084
}
1085

1086
func TestTransportRejectsInvalidHeaders_h1(t *testing.T) {
1087
	testTransportRejectsInvalidHeaders(t, h1Mode)
1088
}
1089
func TestTransportRejectsInvalidHeaders_h2(t *testing.T) {
1090
	testTransportRejectsInvalidHeaders(t, h2Mode)
1091
}
1092
func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
1093
	setParallel(t)
1094
	defer afterTest(t)
1095
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1096
		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1097
	}), optQuietLog)
1098
	defer cst.close()
1099
	cst.tr.DisableKeepAlives = true
1100

1101
	tests := []struct {
1102
		key, val string
1103
		ok       bool
1104
	}{
1105
		{"Foo", "capital-key", true}, // verify h2 allows capital keys
1106
		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
1107
		{"Foo", "two\nlines", false}, // \n byte in value not allowed
1108
		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
1109
		{"A space", "v", false},      // spaces in keys not allowed
1110
		{"имя", "v", false},          // key must be ascii
1111
		{"name", "валю", true},       // value may be non-ascii
1112
		{"", "v", false},             // key must be non-empty
1113
		{"k", "", true},              // value may be empty
1114
	}
1115
	for _, tt := range tests {
1116
		dialedc := make(chan bool, 1)
1117
		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1118
			dialedc <- true
1119
			return net.Dial(netw, addr)
1120
		}
1121
		req, _ := NewRequest("GET", cst.ts.URL, nil)
1122
		req.Header[tt.key] = []string{tt.val}
1123
		res, err := cst.c.Do(req)
1124
		var body []byte
1125
		if err == nil {
1126
			body, _ = ioutil.ReadAll(res.Body)
1127
			res.Body.Close()
1128
		}
1129
		var dialed bool
1130
		select {
1131
		case <-dialedc:
1132
			dialed = true
1133
		default:
1134
		}
1135

1136
		if !tt.ok && dialed {
1137
			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1138
		} else if (err == nil) != tt.ok {
1139
			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1140
		}
1141
	}
1142
}
1143

1144
func TestInterruptWithPanic_h1(t *testing.T)     { testInterruptWithPanic(t, h1Mode, "boom") }
1145
func TestInterruptWithPanic_h2(t *testing.T)     { testInterruptWithPanic(t, h2Mode, "boom") }
1146
func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) }
1147
func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) }
1148
func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
1149
	testInterruptWithPanic(t, h1Mode, ErrAbortHandler)
1150
}
1151
func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
1152
	testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
1153
}
1154
func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) {
1155
	setParallel(t)
1156
	const msg = "hello"
1157
	defer afterTest(t)
1158

1159
	testDone := make(chan struct{})
1160
	defer close(testDone)
1161

1162
	var errorLog lockedBytesBuffer
1163
	gotHeaders := make(chan bool, 1)
1164
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1165
		io.WriteString(w, msg)
1166
		w.(Flusher).Flush()
1167

1168
		select {
1169
		case <-gotHeaders:
1170
		case <-testDone:
1171
		}
1172
		panic(panicValue)
1173
	}), func(ts *httptest.Server) {
1174
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1175
	})
1176
	defer cst.close()
1177
	res, err := cst.c.Get(cst.ts.URL)
1178
	if err != nil {
1179
		t.Fatal(err)
1180
	}
1181
	gotHeaders <- true
1182
	defer res.Body.Close()
1183
	slurp, err := ioutil.ReadAll(res.Body)
1184
	if string(slurp) != msg {
1185
		t.Errorf("client read %q; want %q", slurp, msg)
1186
	}
1187
	if err == nil {
1188
		t.Errorf("client read all successfully; want some error")
1189
	}
1190
	logOutput := func() string {
1191
		errorLog.Lock()
1192
		defer errorLog.Unlock()
1193
		return errorLog.String()
1194
	}
1195
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1196

1197
	if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
1198
		gotLog := logOutput()
1199
		if !wantStackLogged {
1200
			if gotLog == "" {
1201
				return nil
1202
			}
1203
			return fmt.Errorf("want no log output; got: %s", gotLog)
1204
		}
1205
		if gotLog == "" {
1206
			return fmt.Errorf("wanted a stack trace logged; got nothing")
1207
		}
1208
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1209
			return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
1210
		}
1211
		return nil
1212
	}); err != nil {
1213
		t.Fatal(err)
1214
	}
1215
}
1216

1217
// lockedBytesBuffer is a bytes.Buffer paired with a mutex. Only Write takes
// the lock itself; callers using any other Buffer method must hold the lock
// around the call if writers may be running concurrently.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
	b.Mutex.Lock()
	defer b.Mutex.Unlock()
	return b.Buffer.Write(p)
}
1227

1228
// Issue 15366
1229
func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1230
	h12Compare{
1231
		Handler: func(w ResponseWriter, r *Request) {
1232
			h := w.Header()
1233
			h.Set("Content-Encoding", "gzip")
1234
			h.Set("Content-Length", "23")
1235
			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1236
		},
1237
		EarlyCheckResponse: func(proto string, res *Response) {
1238
			if !res.Uncompressed {
1239
				t.Errorf("%s: expected Uncompressed to be set", proto)
1240
			}
1241
			dump, err := httputil.DumpResponse(res, true)
1242
			if err != nil {
1243
				t.Errorf("%s: DumpResponse: %v", proto, err)
1244
				return
1245
			}
1246
			if strings.Contains(string(dump), "Connection: close") {
1247
				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1248
			}
1249
			if !strings.Contains(string(dump), "FOO") {
1250
				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1251
			}
1252
		},
1253
	}.run(t)
1254
}
1255

1256
// Issue 14607
1257
func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) }
1258
func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) }
1259
func testCloseIdleConnections(t *testing.T, h2 bool) {
1260
	setParallel(t)
1261
	defer afterTest(t)
1262
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1263
		w.Header().Set("X-Addr", r.RemoteAddr)
1264
	}))
1265
	defer cst.close()
1266
	get := func() string {
1267
		res, err := cst.c.Get(cst.ts.URL)
1268
		if err != nil {
1269
			t.Fatal(err)
1270
		}
1271
		res.Body.Close()
1272
		v := res.Header.Get("X-Addr")
1273
		if v == "" {
1274
			t.Fatal("didn't get X-Addr")
1275
		}
1276
		return v
1277
	}
1278
	a1 := get()
1279
	cst.tr.CloseIdleConnections()
1280
	a2 := get()
1281
	if a1 == a2 {
1282
		t.Errorf("didn't close connection")
1283
	}
1284
}
1285

1286
// noteCloseConn wraps a net.Conn and runs closeFunc every time Close is
// called, before the underlying connection is closed.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close invokes the close hook and then closes the wrapped connection,
// returning that connection's Close error.
func (x noteCloseConn) Close() error {
	notify := x.closeFunc
	notify()
	inner := x.Conn
	return inner.Close()
}
1295

1296
// testErrorReader is an io.Reader that must never be read from: any Read
// call is reported as a test error.
type testErrorReader struct {
	t *testing.T
}

// Read records a test error and reports end of input without touching p.
func (r testErrorReader) Read([]byte) (int, error) {
	r.t.Error("unexpected Read call")
	return 0, io.EOF
}
1302

1303
func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) }
1304
func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) }
1305

1306
func testNoSniffExpectRequestBody(t *testing.T, h2 bool) {
1307
	defer afterTest(t)
1308
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1309
		w.WriteHeader(StatusUnauthorized)
1310
	}))
1311
	defer cst.close()
1312

1313
	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
1314
	cst.tr.ExpectContinueTimeout = 10 * time.Second
1315

1316
	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1317
	if err != nil {
1318
		t.Fatal(err)
1319
	}
1320
	req.ContentLength = 0 // so transport is tempted to sniff it
1321
	req.Header.Set("Expect", "100-continue")
1322
	res, err := cst.tr.RoundTrip(req)
1323
	if err != nil {
1324
		t.Fatal(err)
1325
	}
1326
	defer res.Body.Close()
1327
	if res.StatusCode != StatusUnauthorized {
1328
		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1329
	}
1330
}
1331

1332
func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) }
1333
func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) }
1334
func testServerUndeclaredTrailers(t *testing.T, h2 bool) {
1335
	defer afterTest(t)
1336
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1337
		w.Header().Set("Foo", "Bar")
1338
		w.Header().Set("Trailer:Foo", "Baz")
1339
		w.(Flusher).Flush()
1340
		w.Header().Add("Trailer:Foo", "Baz2")
1341
		w.Header().Set("Trailer:Bar", "Quux")
1342
	}))
1343
	defer cst.close()
1344
	res, err := cst.c.Get(cst.ts.URL)
1345
	if err != nil {
1346
		t.Fatal(err)
1347
	}
1348
	if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
1349
		t.Fatal(err)
1350
	}
1351
	res.Body.Close()
1352
	delete(res.Header, "Date")
1353
	delete(res.Header, "Content-Type")
1354

1355
	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1356
		t.Errorf("Header = %#v; want %#v", res.Header, want)
1357
	}
1358
	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1359
		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1360
	}
1361
}
1362

1363
func TestBadResponseAfterReadingBody(t *testing.T) {
1364
	defer afterTest(t)
1365
	cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) {
1366
		_, err := io.Copy(ioutil.Discard, r.Body)
1367
		if err != nil {
1368
			t.Fatal(err)
1369
		}
1370
		c, _, err := w.(Hijacker).Hijack()
1371
		if err != nil {
1372
			t.Fatal(err)
1373
		}
1374
		defer c.Close()
1375
		fmt.Fprintln(c, "some bogus crap")
1376
	}))
1377
	defer cst.close()
1378

1379
	closes := 0
1380
	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1381
	if err == nil {
1382
		res.Body.Close()
1383
		t.Fatal("expected an error to be returned from Post")
1384
	}
1385
	if closes != 1 {
1386
		t.Errorf("closes = %d; want 1", closes)
1387
	}
1388
}
1389

1390
func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) }
1391
func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) }
1392
func testWriteHeader0(t *testing.T, h2 bool) {
1393
	defer afterTest(t)
1394
	gotpanic := make(chan bool, 1)
1395
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1396
		defer close(gotpanic)
1397
		defer func() {
1398
			if e := recover(); e != nil {
1399
				got := fmt.Sprintf("%T, %v", e, e)
1400
				want := "string, invalid WriteHeader code 0"
1401
				if got != want {
1402
					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1403
				}
1404
				gotpanic <- true
1405

1406
				// Set an explicit 503. This also tests that the WriteHeader call panics
1407
				// before it recorded that an explicit value was set and that bogus
1408
				// value wasn't stuck.
1409
				w.WriteHeader(503)
1410
			}
1411
		}()
1412
		w.WriteHeader(0)
1413
	}))
1414
	defer cst.close()
1415
	res, err := cst.c.Get(cst.ts.URL)
1416
	if err != nil {
1417
		t.Fatal(err)
1418
	}
1419
	if res.StatusCode != 503 {
1420
		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1421
	}
1422
	if !<-gotpanic {
1423
		t.Error("expected panic in handler")
1424
	}
1425
}
1426

1427
// Issue 23010: don't be super strict checking WriteHeader's code if
1428
// it's not even valid to call WriteHeader then anyway.
1429
func TestWriteHeaderNoCodeCheck_h1(t *testing.T)       { testWriteHeaderAfterWrite(t, h1Mode, false) }
1430
func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) }
1431
func TestWriteHeaderNoCodeCheck_h2(t *testing.T)       { testWriteHeaderAfterWrite(t, h2Mode, false) }
1432
func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) {
1433
	setParallel(t)
1434
	defer afterTest(t)
1435

1436
	var errorLog lockedBytesBuffer
1437
	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
1438
		if hijack {
1439
			conn, _, _ := w.(Hijacker).Hijack()
1440
			defer conn.Close()
1441
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1442
			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
1443
			conn.Write([]byte("bar"))
1444
			return
1445
		}
1446
		io.WriteString(w, "foo")
1447
		w.(Flusher).Flush()
1448
		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
1449
		io.WriteString(w, "bar")
1450
	}), func(ts *httptest.Server) {
1451
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1452
	})
1453
	defer cst.close()
1454
	res, err := cst.c.Get(cst.ts.URL)
1455
	if err != nil {
1456
		t.Fatal(err)
1457
	}
1458
	defer res.Body.Close()
1459
	body, err := ioutil.ReadAll(res.Body)
1460
	if err != nil {
1461
		t.Fatal(err)
1462
	}
1463
	if got, want := string(body), "foobar"; got != want {
1464
		t.Errorf("got = %q; want %q", got, want)
1465
	}
1466

1467
	// Also check the stderr output:
1468
	if h2 {
1469
		// TODO: also emit this log message for HTTP/2?
1470
		// We historically haven't, so don't check.
1471
		return
1472
	}
1473
	gotLog := strings.TrimSpace(errorLog.String())
1474
	wantLog := "http: multiple response.WriteHeader calls"
1475
	if hijack {
1476
		wantLog = "http: response.WriteHeader on hijacked connection"
1477
	}
1478
	if gotLog != wantLog {
1479
		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1480
	}
1481
}
1482

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.