cubefs

Форк
0
168 строк · 3.7 Кб
1
// Copyright 2018 The Go Authors. All rights reserved.
2
// Use of this source code is governed by a BSD-style
3
// license that can be found in the LICENSE file.
4

5
package socks
6

7
import (
8
	"context"
9
	"errors"
10
	"io"
11
	"net"
12
	"strconv"
13
	"time"
14
)
15

16
var (
17
	noDeadline   = time.Time{}
18
	aLongTimeAgo = time.Unix(1, 0)
19
)
20

21
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
22
	host, port, err := splitHostPort(address)
23
	if err != nil {
24
		return nil, err
25
	}
26
	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
27
		c.SetDeadline(deadline)
28
		defer c.SetDeadline(noDeadline)
29
	}
30
	if ctx != context.Background() {
31
		errCh := make(chan error, 1)
32
		done := make(chan struct{})
33
		defer func() {
34
			close(done)
35
			if ctxErr == nil {
36
				ctxErr = <-errCh
37
			}
38
		}()
39
		go func() {
40
			select {
41
			case <-ctx.Done():
42
				c.SetDeadline(aLongTimeAgo)
43
				errCh <- ctx.Err()
44
			case <-done:
45
				errCh <- nil
46
			}
47
		}()
48
	}
49

50
	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
51
	b = append(b, Version5)
52
	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
53
		b = append(b, 1, byte(AuthMethodNotRequired))
54
	} else {
55
		ams := d.AuthMethods
56
		if len(ams) > 255 {
57
			return nil, errors.New("too many authentication methods")
58
		}
59
		b = append(b, byte(len(ams)))
60
		for _, am := range ams {
61
			b = append(b, byte(am))
62
		}
63
	}
64
	if _, ctxErr = c.Write(b); ctxErr != nil {
65
		return
66
	}
67

68
	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
69
		return
70
	}
71
	if b[0] != Version5 {
72
		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
73
	}
74
	am := AuthMethod(b[1])
75
	if am == AuthMethodNoAcceptableMethods {
76
		return nil, errors.New("no acceptable authentication methods")
77
	}
78
	if d.Authenticate != nil {
79
		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
80
			return
81
		}
82
	}
83

84
	b = b[:0]
85
	b = append(b, Version5, byte(d.cmd), 0)
86
	if ip := net.ParseIP(host); ip != nil {
87
		if ip4 := ip.To4(); ip4 != nil {
88
			b = append(b, AddrTypeIPv4)
89
			b = append(b, ip4...)
90
		} else if ip6 := ip.To16(); ip6 != nil {
91
			b = append(b, AddrTypeIPv6)
92
			b = append(b, ip6...)
93
		} else {
94
			return nil, errors.New("unknown address type")
95
		}
96
	} else {
97
		if len(host) > 255 {
98
			return nil, errors.New("FQDN too long")
99
		}
100
		b = append(b, AddrTypeFQDN)
101
		b = append(b, byte(len(host)))
102
		b = append(b, host...)
103
	}
104
	b = append(b, byte(port>>8), byte(port))
105
	if _, ctxErr = c.Write(b); ctxErr != nil {
106
		return
107
	}
108

109
	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
110
		return
111
	}
112
	if b[0] != Version5 {
113
		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
114
	}
115
	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
116
		return nil, errors.New("unknown error " + cmdErr.String())
117
	}
118
	if b[2] != 0 {
119
		return nil, errors.New("non-zero reserved field")
120
	}
121
	l := 2
122
	var a Addr
123
	switch b[3] {
124
	case AddrTypeIPv4:
125
		l += net.IPv4len
126
		a.IP = make(net.IP, net.IPv4len)
127
	case AddrTypeIPv6:
128
		l += net.IPv6len
129
		a.IP = make(net.IP, net.IPv6len)
130
	case AddrTypeFQDN:
131
		if _, err := io.ReadFull(c, b[:1]); err != nil {
132
			return nil, err
133
		}
134
		l += int(b[0])
135
	default:
136
		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
137
	}
138
	if cap(b) < l {
139
		b = make([]byte, l)
140
	} else {
141
		b = b[:l]
142
	}
143
	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
144
		return
145
	}
146
	if a.IP != nil {
147
		copy(a.IP, b)
148
	} else {
149
		a.Name = string(b[:len(b)-2])
150
	}
151
	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
152
	return &a, nil
153
}
154

155
func splitHostPort(address string) (string, int, error) {
156
	host, port, err := net.SplitHostPort(address)
157
	if err != nil {
158
		return "", 0, err
159
	}
160
	portnum, err := strconv.Atoi(port)
161
	if err != nil {
162
		return "", 0, err
163
	}
164
	if 1 > portnum || portnum > 0xffff {
165
		return "", 0, errors.New("port number out of range " + port)
166
	}
167
	return host, portnum, nil
168
}
169

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.