1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 package waffle.shiro.negotiate;
25
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
import org.apache.shiro.web.util.WebUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import waffle.util.AuthorizationHeader;
import waffle.util.NtlmServletRequest;
48
49
50
51
52
53
54
55
56
57
58 public class NegotiateAuthenticationFilter extends AuthenticatingFilter {
59
60
61
62
63 private static final Logger LOGGER = LoggerFactory.getLogger(NegotiateAuthenticationFilter.class);
64
65
66
67
68
69
70 private static final List<String> PROTOCOLS = new ArrayList<>();
71
72
73 private String failureKeyAttribute = FormAuthenticationFilter.DEFAULT_ERROR_KEY_ATTRIBUTE_NAME;
74
75
76 private String rememberMeParam = FormAuthenticationFilter.DEFAULT_REMEMBER_ME_PARAM;
77
78
79
80
81 public NegotiateAuthenticationFilter() {
82 NegotiateAuthenticationFilter.PROTOCOLS.add("Negotiate");
83 NegotiateAuthenticationFilter.PROTOCOLS.add("NTLM");
84 }
85
86
87
88
89
90
91 public String getRememberMeParam() {
92 return this.rememberMeParam;
93 }
94
95
96
97
98
99
100
101
102
103
104
105
106 public void setRememberMeParam(final String value) {
107 this.rememberMeParam = value;
108 }
109
110 @Override
111 protected boolean isRememberMe(final ServletRequest request) {
112 return WebUtils.isTrue(request, this.getRememberMeParam());
113 }
114
115 @Override
116 protected AuthenticationToken createToken(final ServletRequest request, final ServletResponse response) {
117 final String authorization = this.getAuthzHeader(request);
118 final String[] elements = authorization.split(" ", -1);
119 final byte[] inToken = Base64.getDecoder().decode(elements[1]);
120
121
122
123 final String connectionId = NtlmServletRequest.getConnectionId((HttpServletRequest) request);
124 final String securityPackage = elements[0];
125
126
127 final AuthorizationHeader authorizationHeader = new AuthorizationHeader((HttpServletRequest) request);
128 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
129
130 NegotiateAuthenticationFilter.LOGGER.debug("security package: {}, connection id: {}, ntlmPost: {}",
131 securityPackage, connectionId, Boolean.valueOf(ntlmPost));
132
133 final boolean rememberMe = this.isRememberMe(request);
134 final String host = this.getHost(request);
135
136 return new NegotiateToken(inToken, new byte[0], connectionId, securityPackage, ntlmPost, rememberMe, host);
137 }
138
139 @Override
140 protected boolean onLoginSuccess(final AuthenticationToken token, final Subject subject,
141 final ServletRequest request, final ServletResponse response) throws Exception {
142 request.setAttribute("MY_SUBJECT", ((NegotiateToken) token).getSubject());
143 return true;
144 }
145
146 @Override
147 protected boolean onLoginFailure(final AuthenticationToken token, final AuthenticationException e,
148 final ServletRequest request, final ServletResponse response) {
149 if (e instanceof AuthenticationInProgressException) {
150
151 final String protocol = this.getAuthzHeaderProtocol(request);
152 NegotiateAuthenticationFilter.LOGGER.debug("Negotiation in progress for protocol: {}", protocol);
153 this.sendChallengeDuringNegotiate(protocol, response, ((NegotiateToken) token).getOut());
154 return false;
155 }
156 NegotiateAuthenticationFilter.LOGGER.warn("login exception: {}", e.getMessage());
157
158
159 this.sendChallengeOnFailure(response);
160
161 this.setFailureAttribute(request, e);
162 return true;
163 }
164
165
166
167
168
169
170
171
172
173 protected void setFailureAttribute(final ServletRequest request, final AuthenticationException ae) {
174 final String className = ae.getClass().getName();
175 request.setAttribute(this.getFailureKeyAttribute(), className);
176 }
177
178
179
180
181
182
183 public String getFailureKeyAttribute() {
184 return this.failureKeyAttribute;
185 }
186
187
188
189
190
191
192
193 public void setFailureKeyAttribute(final String value) {
194 this.failureKeyAttribute = value;
195 }
196
197 @Override
198 protected boolean onAccessDenied(final ServletRequest request, final ServletResponse response) throws Exception {
199
200 boolean loggedIn = false;
201
202 if (this.isLoginAttempt(request)) {
203 loggedIn = this.executeLogin(request, response);
204 } else {
205 NegotiateAuthenticationFilter.LOGGER.debug("authorization required, supported protocols: {}",
206 NegotiateAuthenticationFilter.PROTOCOLS);
207 this.sendChallengeInitiateNegotiate(response);
208 }
209 return loggedIn;
210 }
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226 private String getAuthzHeader(final ServletRequest request) {
227 final HttpServletRequest httpRequest = WebUtils.toHttp(request);
228 return httpRequest.getHeader("Authorization");
229 }
230
231
232
233
234
235
236
237
238
239 private String getAuthzHeaderProtocol(final ServletRequest request) {
240 final String authzHeader = this.getAuthzHeader(request);
241 return authzHeader.substring(0, authzHeader.indexOf(' '));
242 }
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258 private boolean isLoginAttempt(final ServletRequest request) {
259 final String authzHeader = this.getAuthzHeader(request);
260 return authzHeader != null && this.isLoginAttempt(authzHeader);
261 }
262
263
264
265
266
267
268
269
270
271
272
273
274 boolean isLoginAttempt(final String authzHeader) {
275 for (final String protocol : NegotiateAuthenticationFilter.PROTOCOLS) {
276 if (authzHeader.toLowerCase(Locale.ENGLISH).startsWith(protocol.toLowerCase(Locale.ENGLISH))) {
277 return true;
278 }
279 }
280 return false;
281 }
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296 private void sendChallenge(final List<String> protocols, final ServletResponse response, final byte[] out) {
297 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
298 this.sendAuthenticateHeader(protocols, out, httpResponse);
299 httpResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
300 }
301
302
303
304
305
306
307
308 void sendChallengeInitiateNegotiate(final ServletResponse response) {
309 this.sendChallenge(NegotiateAuthenticationFilter.PROTOCOLS, response, null);
310 }
311
312
313
314
315
316
317
318
319
320
321
322 void sendChallengeDuringNegotiate(final String protocol, final ServletResponse response, final byte[] out) {
323 final List<String> protocolsList = new ArrayList<>();
324 protocolsList.add(protocol);
325 this.sendChallenge(protocolsList, response, out);
326 }
327
328
329
330
331
332
333
334 void sendChallengeOnFailure(final ServletResponse response) {
335 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
336 this.sendUnauthorized(NegotiateAuthenticationFilter.PROTOCOLS, null, httpResponse);
337 httpResponse.setHeader("Connection", "close");
338 try {
339 httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
340 httpResponse.flushBuffer();
341 } catch (final IOException e) {
342 throw new RuntimeException(e);
343 }
344 }
345
346
347
348
349
350
351
352
353
354
355
356 private void sendAuthenticateHeader(final List<String> protocolsList, final byte[] out,
357 final HttpServletResponse httpResponse) {
358 this.sendUnauthorized(protocolsList, out, httpResponse);
359 httpResponse.setHeader("Connection", "keep-alive");
360 }
361
362
363
364
365
366
367
368
369
370
371
372 private void sendUnauthorized(final List<String> protocols, final byte[] out, final HttpServletResponse response) {
373 for (final String protocol : protocols) {
374 if (out == null || out.length == 0) {
375 response.addHeader("WWW-Authenticate", protocol);
376 } else {
377 response.setHeader("WWW-Authenticate", protocol + " " + Base64.getEncoder().encodeToString(out));
378 }
379 }
380 }
381
382 }