reprogl

Форк
0
/
oauth.go 
246 строк · 6.0 Кб
1
package handlers
2

3
import (
4
	"encoding/json"
5
	"errors"
6
	"fmt"
7
	"net/http"
8
	"time"
9

10
	"github.com/go-chi/chi/v5"
11
	"golang.org/x/oauth2"
12
	"xelbot.com/reprogl/api/backend"
13
	"xelbot.com/reprogl/container"
14
	"xelbot.com/reprogl/models/repositories"
15
	"xelbot.com/reprogl/services/oauth"
16
	"xelbot.com/reprogl/session"
17
	"xelbot.com/reprogl/views"
18
)
19

20
type (
	// oauthCallbackState is the JSON document kept in the string cache under
	// an OAuth request ID. asyncCallback and oauthCallbackError write it;
	// OAuthCheckState reads it back.
	oauthCallbackState struct {
		// Status is one of "pending", "ok" or "error" in the values written by this file.
		Status string `json:"status"`

		// UserName is the backend username, set only on success.
		UserName string `json:"username,omitempty"`

		// NickName is the display name used in the greeting flash message.
		NickName string `json:"nickname,omitempty"`
	}

	// oauthStateResponse is the JSON body OAuthCheckState sends to the
	// polling page.
	oauthStateResponse struct {
		// Status mirrors oauthCallbackState.Status.
		Status string `json:"status"`

		// RedirectURL is filled only after a successful login.
		RedirectURL string `json:"redirect_url,omitempty"`
	}
)
30

31
// OAuthLogin returns the handler that starts an OAuth authorization-code flow
// for the provider named by the "provider" URL parameter.
//
// An unknown provider (oauth.ConfigByProvider returns an error) yields 404.
// Otherwise the handler saves the login referer, stores a fresh state token
// and a PKCE verifier in the session, and answers with a 302 redirect to the
// provider's authorization URL carrying the state and the S256 challenge.
func OAuthLogin(app *container.Application) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		provider := chi.URLParam(r, "provider")
		app.InfoLog.Println("[OAUTH] start authorization by: " + provider)

		cfg, cfgErr := oauth.ConfigByProvider(provider)
		if cfgErr != nil {
			app.NotFound(w)
			return
		}

		saveLoginReferer(w, r)

		ctx := r.Context()

		// OAuthCallback compares this token with the "state" query parameter.
		csrfState := generateRandomToken()
		session.Put(ctx, session.OAuthStateKey, csrfState)

		// The PKCE verifier stays in the session; only its S256 challenge is
		// sent to the provider.
		pkceVerifier := oauth2.GenerateVerifier()
		session.Put(ctx, session.OAuthVerifierKey, pkceVerifier)

		authURL := cfg.AuthCodeURL(csrfState, oauth2.S256ChallengeOption(pkceVerifier))
		app.InfoLog.Println("[OAUTH] redirect to: " + authURL)

		http.Redirect(w, r, authURL, http.StatusFound)
	}
}
57

58
func OAuthCallback(app *container.Application) http.HandlerFunc {
59
	return func(w http.ResponseWriter, r *http.Request) {
60
		providerName := chi.URLParam(r, "provider")
61

62
		app.InfoLog.Println("[OAUTH] callback from: " + providerName)
63
		if !oauth.SupportedProvider(providerName) {
64
			app.NotFound(w)
65

66
			return
67
		}
68

69
		state, _ := session.Pop[string](r.Context(), session.OAuthStateKey)
70
		stateFromRequest := r.FormValue("state")
71

72
		if len(state) == 0 || len(stateFromRequest) == 0 || stateFromRequest != state {
73
			app.InfoLog.Println("[OAUTH] Invalid state")
74
			app.ClientError(w, http.StatusBadRequest)
75

76
			return
77
		}
78

79
		var found bool
80
		verifier, found := session.Pop[string](r.Context(), session.OAuthVerifierKey)
81
		if !found {
82
			app.ServerError(w, errors.New("[OAUTH] PKCE verifier not found"))
83

84
			return
85
		}
86

87
		code := r.FormValue("code")
88
		if len(code) == 0 {
89
			errorCode := r.FormValue("error")
90
			errorDescription := r.FormValue("error_description")
91
			if len(errorCode) > 0 {
92
				app.InfoLog.Printf("[OAUTH] Error code: %s, description: %s\n", errorCode, errorDescription)
93
			} else {
94
				app.InfoLog.Println("[OAUTH] Error: empty code")
95
			}
96
			app.ClientError(w, http.StatusBadRequest)
97

98
			return
99
		}
100

101
		additional := make(map[string]string)
102
		for _, key := range oauth.AdditionalParams(providerName) {
103
			additional[key] = r.FormValue(key)
104
		}
105

106
		requestID := generateRandomToken()
107
		go asyncCallback(requestID, providerName, code, verifier, r.UserAgent(), container.RealRemoteAddress(r), additional, app)
108

109
		templateData := views.NewOauthPendingPageData(requestID)
110
		err := views.WriteTemplate(w, "oauth-pending.gohtml", templateData)
111
		if err != nil {
112
			app.ServerError(w, err)
113

114
			return
115
		}
116
	}
117
}
118

119
// asyncCallback completes an OAuth login away from the request goroutine.
// OAuthCallback starts it with a go statement. It records the outcome in the
// string cache under requestID, with a one-minute TTL, for OAuthCheckState to
// report.
//
// Steps:
//  1. Store {"status":"pending"}.
//  2. Exchange the authorization code, the PKCE verifier and any
//     provider-specific params for user data.
//  3. Send that data plus the client's user agent and IP to the backend.
//  4. Store {"status":"ok", "username", "nickname"} for the returned user.
//
// Every failure path is recorded as {"status":"error"} via oauthCallbackError:
// the exchange failing, the backend call failing, validation violations, a
// response without a user, or a JSON encoding error.
func asyncCallback(
	requestID,
	providerName,
	code,
	verifier,
	userAgent,
	ip string,
	additional map[string]string,
	app *container.Application,
) {
	cache := app.GetStringCache()
	cache.Set(requestID, `{"status":"pending"}`, time.Minute)

	userData, err := oauth.UserDataByCode(providerName, code, verifier, additional)
	if err != nil {
		oauthCallbackError(app, requestID, err)
		return
	}

	userDataDTO := backend.ExternalUserDTO{
		UserData:  userData,
		UserAgent: userAgent,
		IP:        ip,
	}

	apiResponse, err := backend.SendUserData(userDataDTO)
	if err != nil {
		oauthCallbackError(app, requestID, err)
		return
	}

	if len(apiResponse.Violations) > 0 {
		errorMessage := "[OAUTH] user validation error:\n"
		for _, formError := range apiResponse.Violations {
			app.InfoLog.Printf("[OAUTH] validation error: %s - %s\n", formError.Path, formError.Message)
			errorMessage += fmt.Sprintf("%s: %s\n", formError.Path, formError.Message)
		}

		// err is nil on this path; report the collected violations instead of
		// handing a nil error to the logger.
		oauthCallbackError(app, requestID, errors.New(errorMessage))
		return
	}

	if apiResponse.User == nil {
		// Neither violations nor a user: fail explicitly rather than leaving the
		// request "pending" until its cache entry expires.
		oauthCallbackError(app, requestID, errors.New("[OAUTH] backend response contains neither user nor violations"))
		return
	}

	jsonBody, err := json.Marshal(oauthCallbackState{
		Status:   "ok",
		UserName: apiResponse.User.Username,
		NickName: apiResponse.User.Nickname(),
	})
	if err != nil {
		oauthCallbackError(app, requestID, err)
		return
	}
	cache.Set(requestID, string(jsonBody), time.Minute)
}
180

181
// OAuthCheckState returns the handler polled by the OAuth pending page.
//
// It reads the state stored under the "request_id" URL parameter and replies
// with JSON {"status": ..., "redirect_url": ...}. An unknown or expired request
// ID yields 404, and a stored value that is not valid JSON yields 500.
//
// When the status is "ok" and a user name is present, the handler:
//   - loads the user;
//   - queues a greeting flash message;
//   - calls authSuccess;
//   - sets redirect_url to the saved login referer, or "/" if there is none.
//
// NOTE(review): the cache entry is not consumed after a successful check, so
// repeated polls with the same request ID call authSuccess again until the
// entry expires. The request ID is also not bound to the session that started
// the flow. Consider invalidating the entry once it has been used.
func OAuthCheckState(app *container.Application) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		requestID := chi.URLParam(r, "request_id")

		stateString, found := app.GetStringCache().Get(requestID)
		if !found {
			// requestID comes from the URL, so it is quoted to prevent log-line injection.
			app.InfoLog.Printf("[OAUTH] requestID not found: %q\n", requestID)
			app.NotFound(w)
			return
		}

		// json.Unmarshal reports malformed input itself; no separate json.Valid pass is needed.
		var oauthState oauthCallbackState
		if err := json.Unmarshal([]byte(stateString), &oauthState); err != nil {
			app.ServerError(w, fmt.Errorf("[OAUTH] invalid JSON state: %w", err))
			return
		}

		responseData := oauthStateResponse{Status: oauthState.Status}

		if oauthState.Status == "ok" && len(oauthState.UserName) > 0 {
			repo := repositories.UserRepository{DB: app.DB}
			user, err := repo.GetLoggedUserByUsername(oauthState.UserName)
			if err != nil {
				app.ServerError(w, err)
				return
			}

			// Queue the greeting only once the user is known, so a failed lookup
			// does not leave a stale "hello" flash in the session.
			session.Put(r.Context(), session.FlashSuccessKey, fmt.Sprintf("Привет, %s :)", oauthState.NickName))

			app.InfoLog.Printf("[OAUTH] success for \"%s\"\n", user.Username)
			authSuccess(user, app, container.RealRemoteAddress(r), r.Context())

			redirectURL := "/"
			if referer, ok := popLoginReferer(w, r); ok {
				redirectURL = referer
			}
			responseData.RedirectURL = redirectURL
		}

		jsonResponse(w, http.StatusOK, responseData)
	}
}
240

241
// oauthCallbackError logs err and marks the OAuth request identified by
// requestID as failed. It does so by storing {"status":"error"} in the string
// cache for one minute, which OAuthCheckState then reports to the polling page.
//
// A nil err is replaced by a generic error so the logger never receives nil.
func oauthCallbackError(app *container.Application, requestID string, err error) {
	if err == nil {
		err = errors.New("[OAUTH] unknown error")
	}
	app.LogError(err)

	app.GetStringCache().Set(requestID, `{"status":"error"}`, time.Minute)
}
247

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.