1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 package waffle.shiro.negotiate;
25
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
30
31
32
33
34
35
36
37
38
39
40
41
42 import javax.servlet.ServletRequest;
43 import javax.servlet.ServletResponse;
44 import javax.servlet.http.HttpServletRequest;
45 import javax.servlet.http.HttpServletResponse;
46
47 import org.apache.shiro.authc.AuthenticationException;
48 import org.apache.shiro.authc.AuthenticationToken;
49 import org.apache.shiro.subject.Subject;
50 import org.apache.shiro.web.filter.authc.AuthenticatingFilter;
51 import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
52 import org.apache.shiro.web.util.WebUtils;
53 import org.slf4j.Logger;
54 import org.slf4j.LoggerFactory;
55
56 import waffle.util.AuthorizationHeader;
57 import waffle.util.NtlmServletRequest;
58
59
60
61
62
63
64
65
66
67 public class NegotiateAuthenticationFilter extends AuthenticatingFilter {
68
69
70
71
72 private static final Logger LOGGER = LoggerFactory.getLogger(NegotiateAuthenticationFilter.class);
73
74
75
76
77
78
79 private static final List<String> PROTOCOLS = new ArrayList<>();
80
81
82 private String failureKeyAttribute = FormAuthenticationFilter.DEFAULT_ERROR_KEY_ATTRIBUTE_NAME;
83
84
85 private String rememberMeParam = FormAuthenticationFilter.DEFAULT_REMEMBER_ME_PARAM;
86
87
88
89
90 public NegotiateAuthenticationFilter() {
91 NegotiateAuthenticationFilter.PROTOCOLS.add("Negotiate");
92 NegotiateAuthenticationFilter.PROTOCOLS.add("NTLM");
93 }
94
95
96
97
98
99
100 public String getRememberMeParam() {
101 return this.rememberMeParam;
102 }
103
104
105
106
107
108
109
110
111
112
113
114
115 public void setRememberMeParam(final String value) {
116 this.rememberMeParam = value;
117 }
118
119 @Override
120 protected boolean isRememberMe(final ServletRequest request) {
121 return WebUtils.isTrue(request, this.getRememberMeParam());
122 }
123
124 @Override
125 protected AuthenticationToken createToken(final ServletRequest request, final ServletResponse response) {
126 final String authorization = this.getAuthzHeader(request);
127 final String[] elements = authorization.split(" ", -1);
128 final byte[] inToken = Base64.getDecoder().decode(elements[1]);
129
130
131
132 final String connectionId = NtlmServletRequest.getConnectionId((HttpServletRequest) request);
133 final String securityPackage = elements[0];
134
135
136 final AuthorizationHeader authorizationHeader = new AuthorizationHeader((HttpServletRequest) request);
137 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
138
139 NegotiateAuthenticationFilter.LOGGER.debug("security package: {}, connection id: {}, ntlmPost: {}",
140 securityPackage, connectionId, Boolean.valueOf(ntlmPost));
141
142 final boolean rememberMe = this.isRememberMe(request);
143 final String host = this.getHost(request);
144
145 return new NegotiateToken(inToken, new byte[0], connectionId, securityPackage, ntlmPost, rememberMe, host);
146 }
147
148 @Override
149 protected boolean onLoginSuccess(final AuthenticationToken token, final Subject subject,
150 final ServletRequest request, final ServletResponse response) throws Exception {
151 request.setAttribute("MY_SUBJECT", ((NegotiateToken) token).getSubject());
152 return true;
153 }
154
155 @Override
156 protected boolean onLoginFailure(final AuthenticationToken token, final AuthenticationException e,
157 final ServletRequest request, final ServletResponse response) {
158 if (e instanceof AuthenticationInProgressException) {
159
160 final String protocol = this.getAuthzHeaderProtocol(request);
161 NegotiateAuthenticationFilter.LOGGER.debug("Negotiation in progress for protocol: {}", protocol);
162 this.sendChallengeDuringNegotiate(protocol, response, ((NegotiateToken) token).getOut());
163 return false;
164 }
165 NegotiateAuthenticationFilter.LOGGER.warn("login exception: {}", e.getMessage());
166
167
168 this.sendChallengeOnFailure(response);
169
170 this.setFailureAttribute(request, e);
171 return true;
172 }
173
174
175
176
177
178
179
180
181
182 protected void setFailureAttribute(final ServletRequest request, final AuthenticationException ae) {
183 final String className = ae.getClass().getName();
184 request.setAttribute(this.getFailureKeyAttribute(), className);
185 }
186
187
188
189
190
191
192 public String getFailureKeyAttribute() {
193 return this.failureKeyAttribute;
194 }
195
196
197
198
199
200
201
202 public void setFailureKeyAttribute(final String value) {
203 this.failureKeyAttribute = value;
204 }
205
206 @Override
207 protected boolean onAccessDenied(final ServletRequest request, final ServletResponse response) throws Exception {
208
209 boolean loggedIn = false;
210
211 if (this.isLoginAttempt(request)) {
212 loggedIn = this.executeLogin(request, response);
213 } else {
214 NegotiateAuthenticationFilter.LOGGER.debug("authorization required, supported protocols: {}",
215 NegotiateAuthenticationFilter.PROTOCOLS);
216 this.sendChallengeInitiateNegotiate(response);
217 }
218 return loggedIn;
219 }
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235 private boolean isLoginAttempt(final ServletRequest request) {
236 final String authzHeader = this.getAuthzHeader(request);
237 return authzHeader != null && this.isLoginAttempt(authzHeader);
238 }
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254 private String getAuthzHeader(final ServletRequest request) {
255 final HttpServletRequest httpRequest = WebUtils.toHttp(request);
256 return httpRequest.getHeader("Authorization");
257 }
258
259
260
261
262
263
264
265
266
267 private String getAuthzHeaderProtocol(final ServletRequest request) {
268 final String authzHeader = this.getAuthzHeader(request);
269 return authzHeader.substring(0, authzHeader.indexOf(' '));
270 }
271
272
273
274
275
276
277
278
279
280
281
282
283 boolean isLoginAttempt(final String authzHeader) {
284 for (final String protocol : NegotiateAuthenticationFilter.PROTOCOLS) {
285 if (authzHeader.toLowerCase().startsWith(protocol.toLowerCase())) {
286 return true;
287 }
288 }
289 return false;
290 }
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305 private void sendChallenge(final List<String> protocols, final ServletResponse response, final byte[] out) {
306 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
307 this.sendAuthenticateHeader(protocols, out, httpResponse);
308 httpResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
309 }
310
311
312
313
314
315
316
317 void sendChallengeInitiateNegotiate(final ServletResponse response) {
318 this.sendChallenge(NegotiateAuthenticationFilter.PROTOCOLS, response, null);
319 }
320
321
322
323
324
325
326
327
328
329
330
331 void sendChallengeDuringNegotiate(final String protocol, final ServletResponse response, final byte[] out) {
332 final List<String> protocolsList = new ArrayList<>();
333 protocolsList.add(protocol);
334 this.sendChallenge(protocolsList, response, out);
335 }
336
337
338
339
340
341
342
343 void sendChallengeOnFailure(final ServletResponse response) {
344 final HttpServletResponse httpResponse = WebUtils.toHttp(response);
345 this.sendUnauthorized(NegotiateAuthenticationFilter.PROTOCOLS, null, httpResponse);
346 httpResponse.setHeader("Connection", "close");
347 try {
348 httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
349 httpResponse.flushBuffer();
350 } catch (final IOException e) {
351 throw new RuntimeException(e);
352 }
353 }
354
355
356
357
358
359
360
361
362
363
364
365 private void sendAuthenticateHeader(final List<String> protocolsList, final byte[] out,
366 final HttpServletResponse httpResponse) {
367 this.sendUnauthorized(protocolsList, out, httpResponse);
368 httpResponse.setHeader("Connection", "keep-alive");
369 }
370
371
372
373
374
375
376
377
378
379
380
381 private void sendUnauthorized(final List<String> protocols, final byte[] out, final HttpServletResponse response) {
382 for (final String protocol : protocols) {
383 if (out == null || out.length == 0) {
384 response.addHeader("WWW-Authenticate", protocol);
385 } else {
386 response.setHeader("WWW-Authenticate", protocol + " " + Base64.getEncoder().encodeToString(out));
387 }
388 }
389 }
390
391 }