View Javadoc
1   /*
2    * MIT License
3    *
4    * Copyright (c) 2010-2024 The Waffle Project Contributors: https://github.com/Waffle/waffle/graphs/contributors
5    *
6    * Permission is hereby granted, free of charge, to any person obtaining a copy
7    * of this software and associated documentation files (the "Software"), to deal
8    * in the Software without restriction, including without limitation the rights
9    * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10   * copies of the Software, and to permit persons to whom the Software is
11   * furnished to do so, subject to the following conditions:
12   *
13   * The above copyright notice and this permission notice shall be included in all
14   * copies or substantial portions of the Software.
15   *
16   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22   * SOFTWARE.
23   */
24  package waffle.servlet.spi;
25  
26  import java.io.IOException;
27  import java.security.InvalidParameterException;
28  import java.util.ArrayList;
29  import java.util.Base64;
30  import java.util.List;
31  
32  import javax.servlet.http.HttpServletRequest;
33  import javax.servlet.http.HttpServletResponse;
34  
35  import org.slf4j.Logger;
36  import org.slf4j.LoggerFactory;
37  
38  import waffle.util.AuthorizationHeader;
39  import waffle.util.NtlmServletRequest;
40  import waffle.windows.auth.IWindowsAuthProvider;
41  import waffle.windows.auth.IWindowsIdentity;
42  import waffle.windows.auth.IWindowsSecurityContext;
43  
44  /**
45   * A negotiate security filter provider.
46   */
47  public class NegotiateSecurityFilterProvider implements SecurityFilterProvider {
48  
49      /** The Constant LOGGER. */
50      private static final Logger LOGGER = LoggerFactory.getLogger(NegotiateSecurityFilterProvider.class);
51  
52      /** The Constant WWW_AUTHENTICATE. */
53      private static final String WWW_AUTHENTICATE = "WWW-Authenticate";
54  
55      /** The Constant PROTOCOLS. */
56      private static final String PROTOCOLS = "protocols";
57  
58      /** The Constant NEGOTIATE. */
59      private static final String NEGOTIATE = "Negotiate";
60  
61      /** The Constant NTLM. */
62      private static final String NTLM = "NTLM";
63  
64      /** The protocols. */
65      private List<String> protocolsList = new ArrayList<>();
66  
67      /** The auth. */
68      private final IWindowsAuthProvider auth;
69  
70      /**
71       * Instantiates a new negotiate security filter provider.
72       *
73       * @param newAuthProvider
74       *            the new auth provider
75       */
76      public NegotiateSecurityFilterProvider(final IWindowsAuthProvider newAuthProvider) {
77          this.auth = newAuthProvider;
78          this.protocolsList.add(NegotiateSecurityFilterProvider.NEGOTIATE);
79          this.protocolsList.add(NegotiateSecurityFilterProvider.NTLM);
80      }
81  
82      /**
83       * Gets the protocols.
84       *
85       * @return the protocols
86       */
87      public List<String> getProtocols() {
88          return this.protocolsList;
89      }
90  
91      /**
92       * Sets the protocols.
93       *
94       * @param values
95       *            the new protocols
96       */
97      public void setProtocols(final List<String> values) {
98          this.protocolsList = values;
99      }
100 
101     @Override
102     public void sendUnauthorized(final HttpServletResponse response) {
103         for (final String protocol : this.protocolsList) {
104             response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, protocol);
105         }
106     }
107 
108     @Override
109     public boolean isPrincipalException(final HttpServletRequest request) {
110         final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
111         final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
112         NegotiateSecurityFilterProvider.LOGGER.debug("authorization: {}, ntlm post: {}", authorizationHeader,
113                 Boolean.valueOf(ntlmPost));
114         return ntlmPost;
115     }
116 
117     @Override
118     public IWindowsIdentity doFilter(final HttpServletRequest request, final HttpServletResponse response)
119             throws IOException {
120 
121         final AuthorizationHeader authorizationHeader = new AuthorizationHeader(request);
122         final boolean ntlmPost = authorizationHeader.isNtlmType1PostAuthorizationHeader();
123 
124         // maintain a connection-based session for NTLM tokens
125         final String connectionId = NtlmServletRequest.getConnectionId(request);
126         final String securityPackage = authorizationHeader.getSecurityPackage();
127         NegotiateSecurityFilterProvider.LOGGER.debug("security package: {}, connection id: {}", securityPackage,
128                 connectionId);
129 
130         if (ntlmPost) {
131             // type 2 NTLM authentication message received
132             this.auth.resetSecurityToken(connectionId);
133         }
134 
135         final byte[] tokenBuffer = authorizationHeader.getTokenBytes();
136         NegotiateSecurityFilterProvider.LOGGER.debug("token buffer: {} byte(s)", Integer.valueOf(tokenBuffer.length));
137         final IWindowsSecurityContext securityContext = this.auth.acceptSecurityToken(connectionId, tokenBuffer,
138                 securityPackage);
139 
140         final byte[] continueTokenBytes = securityContext.getToken();
141         if (continueTokenBytes != null && continueTokenBytes.length > 0) {
142             final String continueToken = Base64.getEncoder().encodeToString(continueTokenBytes);
143             NegotiateSecurityFilterProvider.LOGGER.debug("continue token: {}", continueToken);
144             response.addHeader(NegotiateSecurityFilterProvider.WWW_AUTHENTICATE, securityPackage + " " + continueToken);
145         }
146 
147         NegotiateSecurityFilterProvider.LOGGER.debug("continue required: {}",
148                 Boolean.valueOf(securityContext.isContinue()));
149         if (securityContext.isContinue()) {
150             response.setHeader("Connection", "keep-alive");
151             response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
152             response.flushBuffer();
153             return null;
154         }
155 
156         final IWindowsIdentity identity = securityContext.getIdentity();
157         securityContext.dispose();
158         return identity;
159     }
160 
161     @Override
162     public boolean isSecurityPackageSupported(final String securityPackage) {
163         for (final String protocol : this.protocolsList) {
164             if (protocol.equalsIgnoreCase(securityPackage)) {
165                 return true;
166             }
167         }
168         return false;
169     }
170 
171     @Override
172     public void initParameter(final String parameterName, final String parameterValue) {
173         if (NegotiateSecurityFilterProvider.PROTOCOLS.equals(parameterName)) {
174             this.protocolsList = new ArrayList<>();
175             final String[] protocolNames = parameterValue.split("\\s+", -1);
176             for (String protocolName : protocolNames) {
177                 protocolName = protocolName.trim();
178                 if (protocolName.length() > 0) {
179                     NegotiateSecurityFilterProvider.LOGGER.debug("init protocol: {}", protocolName);
180                     if (NegotiateSecurityFilterProvider.NEGOTIATE.equals(protocolName)
181                             || NegotiateSecurityFilterProvider.NTLM.equals(protocolName)) {
182                         this.protocolsList.add(protocolName);
183                     } else {
184                         NegotiateSecurityFilterProvider.LOGGER.error("unsupported protocol: {}", protocolName);
185                         throw new RuntimeException("Unsupported protocol: " + protocolName);
186                     }
187                 }
188             }
189         } else {
190             throw new InvalidParameterException(parameterName);
191         }
192     }
193 }