1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 package waffle.servlet.spi;
25
26 import java.io.IOException;
27 import java.security.InvalidParameterException;
28 import java.util.ArrayList;
29 import java.util.Base64;
30 import java.util.List;
31
32 import javax.servlet.http.HttpServletRequest;
33 import javax.servlet.http.HttpServletResponse;
34
35 import org.slf4j.Logger;
36 import org.slf4j.LoggerFactory;
37
38 import waffle.util.AuthorizationHeader;
39 import waffle.util.NtlmServletRequest;
40 import waffle.windows.auth.IWindowsAuthProvider;
41 import waffle.windows.auth.IWindowsIdentity;
42 import waffle.windows.auth.IWindowsSecurityContext;
43
44
45
46
47 public class NegotiateSecurityFilterProvider implements SecurityFilterProvider {
48
49
50 private static final Logger LOGGER = LoggerFactory.getLogger(NegotiateSecurityFilterProvider.class);
51
52
53 private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
54
55
56 private static final String PROTOCOLS = "protocols";
57
58
59 private static final String NEGOTIATE = "Negotiate";
60
61
62 private static final String NTLM = "NTLM";
63
64
65 private List<String> protocolsList = new ArrayList<>();
66
67
68 private final IWindowsAuthProvider auth;
69
70
71
72
73
74
75
76 public NegotiateSecurityFilterProvider(final IWindowsAuthProvider newAuthProvider) {
77 this.auth = newAuthProvider;
78 this.protocolsList.add(NegotiateSecurityFilterProvider.NEGOTIATE);
79 this.protocolsList.add(NegotiateSecurityFilterProvider.NTLM);
80 }
81
82
83
84
85
86
87 public List<String> getProtocols() {
88 return this.protocolsList;
89 }
90
91
92
93
94
95
96
97 public void setProtocols(final List<String> values) {
98 this.protocolsList = values;
99 }
100
101 @Override
102 public void sendUnauthorized(final HttpServletResponse response) {
103 for (final String protocol : this.protocolsList) {
104 response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, protocol);
105 }
106 }
107
108 @Override
109 public boolean isPrincipalException(final HttpServletRequest request) {
110 final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
111 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
112 NegotiateSecurityFilterProvider.LOGGER.debug("authorization: {}, ntlm post: {}", authorizationHeader,
113 Boolean.valueOf(ntlmPost));
114 return ntlmPost;
115 }
116
117 @Override
118 public IWindowsIdentity doFilter(final HttpServletRequest request, final HttpServletResponse response)
119 throws IOException {
120
121 final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
122 final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
123
124
125 final String connectionId = NtlmServletRequest.getConnectionId(request);
126 final String securityPackage = authorizationHeader.getSecurityPackage();
127 NegotiateSecurityFilterProvider.LOGGER.debug("security package: {}, connection id: {}", securityPackage,
128 connectionId);
129
130 if (ntlmPost) {
131
132 this.auth.resetSecurityToken(connectionId);
133 }
134
135 final byte[] tokenBuffer = authorizationHeader.getTokenBytes();
136 NegotiateSecurityFilterProvider.LOGGER.debug("token buffer: {} byte(s)", Integer.valueOf(tokenBuffer.length));
137 final IWindowsSecurityContext securityContext = this.auth.acceptSecurityToken(connectionId, tokenBuffer,
138 securityPackage);
139
140 final byte[] continueTokenBytes = securityContext.getToken();
141 if (continueTokenBytes != null && continueTokenBytes.length > 0) {
142 final String continueToken = Base64.getEncoder().encodeToString(continueTokenBytes);
143 NegotiateSecurityFilterProvider.LOGGER.debug("continue token: {}", continueToken);
144 response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, securityPackage + " " + continueToken);
145 }
146
147 NegotiateSecurityFilterProvider.LOGGER.debug("continue required: {}",
148 Boolean.valueOf(securityContext.isContinue()));
149 if (securityContext.isContinue()) {
150 response.setHeader("Connection", "keep-alive");
151 response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
152 response.flushBuffer();
153 return null;
154 }
155
156 final IWindowsIdentity identity = securityContext.getIdentity();
157 securityContext.dispose();
158 return identity;
159 }
160
161 @Override
162 public boolean isSecurityPackageSupported(final String securityPackage) {
163 for (final String protocol : this.protocolsList) {
164 if (protocol.equalsIgnoreCase(securityPackage)) {
165 return true;
166 }
167 }
168 return false;
169 }
170
171 @Override
172 public void initParameter(final String parameterName, final String parameterValue) {
173 if (NegotiateSecurityFilterProvider.PROTOCOLS.equals(parameterName)) {
174 this.protocolsList = new ArrayList<>();
175 final String[] protocolNames = parameterValue.split("\\s+", -1);
176 for (String protocolName : protocolNames) {
177 protocolName = protocolName.trim();
178 if (protocolName.length() > 0) {
179 NegotiateSecurityFilterProvider.LOGGER.debug("init protocol: {}", protocolName);
180 if (NegotiateSecurityFilterProvider.NEGOTIATE.equals(protocolName)
181 || NegotiateSecurityFilterProvider.NTLM.equals(protocolName)) {
182 this.protocolsList.add(protocolName);
183 } else {
184 NegotiateSecurityFilterProvider.LOGGER.error("unsupported protocol: {}", protocolName);
185 throw new RuntimeException("Unsupported protocol: " + protocolName);
186 }
187 }
188 }
189 } else {
190 throw new InvalidParameterException(parameterName);
191 }
192 }
193 }