|
40 | 40 | import org.springframework.http.HttpMethod; |
41 | 41 | import org.springframework.lang.Nullable; |
42 | 42 | import org.springframework.security.config.ObjectPostProcessor; |
| 43 | +import org.springframework.security.config.annotation.web.ServletRegistrationsSupport.RegistrationMapping; |
43 | 44 | import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry; |
44 | 45 | import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; |
45 | 46 | import org.springframework.security.web.util.matcher.AntPathRequestMatcher; |
@@ -235,103 +236,31 @@ private boolean anyPathsDontStartWithLeadingSlash(String... patterns) { |
235 | 236 | } |
236 | 237 |
|
237 | 238 | private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) { |
238 | | - Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext); |
239 | | - if (registrations.isEmpty()) { |
| 239 | + ServletRegistrationsSupport registrations = new ServletRegistrationsSupport(servletContext); |
| 240 | + Collection<RegistrationMapping> mappings = registrations.mappings(); |
| 241 | + if (mappings.isEmpty()) { |
240 | 242 | return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); |
241 | 243 | } |
242 | | - if (!hasDispatcherServlet(registrations)) { |
| 244 | + Collection<RegistrationMapping> dispatcherServletMappings = registrations.dispatcherServletMappings(); |
| 245 | + if (dispatcherServletMappings.isEmpty()) { |
243 | 246 | return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); |
244 | 247 | } |
245 | | - ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); |
246 | | - if (dispatcherServlet != null) { |
247 | | - if (registrations.size() == 1) { |
248 | | - return mvc; |
249 | | - } |
250 | | - return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext); |
| 248 | + if (dispatcherServletMappings.size() > 1) { |
| 249 | + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); |
| 250 | + throw new IllegalArgumentException(errorMessage); |
251 | 251 | } |
252 | | - dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); |
253 | | - if (dispatcherServlet != null) { |
254 | | - String mapping = dispatcherServlet.getMappings().iterator().next(); |
255 | | - mvc.setServletPath(mapping.substring(0, mapping.length() - 2)); |
256 | | - return mvc; |
257 | | - } |
258 | | - String errorMessage = computeErrorMessage(registrations.values()); |
259 | | - throw new IllegalArgumentException(errorMessage); |
260 | | - } |
261 | | - |
262 | | - private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) { |
263 | | - Map<String, ServletRegistration> mappable = new LinkedHashMap<>(); |
264 | | - for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations() |
265 | | - .entrySet()) { |
266 | | - if (!entry.getValue().getMappings().isEmpty()) { |
267 | | - mappable.put(entry.getKey(), entry.getValue()); |
268 | | - } |
| 252 | + RegistrationMapping dispatcherServlet = dispatcherServletMappings.iterator().next(); |
| 253 | + if (mappings.size() > 1 && !dispatcherServlet.isDefault()) { |
| 254 | + String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values()); |
| 255 | + throw new IllegalArgumentException(errorMessage); |
269 | 256 | } |
270 | | - return mappable; |
271 | | - } |
272 | | - |
273 | | - private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) { |
274 | | - if (registrations == null) { |
275 | | - return false; |
276 | | - } |
277 | | - for (ServletRegistration registration : registrations.values()) { |
278 | | - if (isDispatcherServlet(registration)) { |
279 | | - return true; |
280 | | - } |
281 | | - } |
282 | | - return false; |
283 | | - } |
284 | | - |
285 | | - private ServletRegistration requireOneRootDispatcherServlet( |
286 | | - Map<String, ? extends ServletRegistration> registrations) { |
287 | | - ServletRegistration rootDispatcherServlet = null; |
288 | | - for (ServletRegistration registration : registrations.values()) { |
289 | | - if (!isDispatcherServlet(registration)) { |
290 | | - continue; |
291 | | - } |
292 | | - if (registration.getMappings().size() > 1) { |
293 | | - return null; |
294 | | - } |
295 | | - if (!"/".equals(registration.getMappings().iterator().next())) { |
296 | | - return null; |
297 | | - } |
298 | | - rootDispatcherServlet = registration; |
299 | | - } |
300 | | - return rootDispatcherServlet; |
301 | | - } |
302 | | - |
303 | | - private ServletRegistration requireOnlyPathMappedDispatcherServlet( |
304 | | - Map<String, ? extends ServletRegistration> registrations) { |
305 | | - ServletRegistration pathDispatcherServlet = null; |
306 | | - for (ServletRegistration registration : registrations.values()) { |
307 | | - if (!isDispatcherServlet(registration)) { |
308 | | - return null; |
309 | | - } |
310 | | - if (registration.getMappings().size() > 1) { |
311 | | - return null; |
312 | | - } |
313 | | - String mapping = registration.getMappings().iterator().next(); |
314 | | - if (!mapping.startsWith("/") || !mapping.endsWith("/*")) { |
315 | | - return null; |
316 | | - } |
317 | | - if (pathDispatcherServlet != null) { |
318 | | - return null; |
| 257 | + if (dispatcherServlet.isDefault()) { |
| 258 | + if (mappings.size() == 1) { |
| 259 | + return mvc; |
319 | 260 | } |
320 | | - pathDispatcherServlet = registration; |
321 | | - } |
322 | | - return pathDispatcherServlet; |
323 | | - } |
324 | | - |
325 | | - private boolean isDispatcherServlet(ServletRegistration registration) { |
326 | | - Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", |
327 | | - null); |
328 | | - try { |
329 | | - Class<?> clazz = Class.forName(registration.getClassName()); |
330 | | - return dispatcherServlet.isAssignableFrom(clazz); |
331 | | - } |
332 | | - catch (ClassNotFoundException ex) { |
333 | | - return false; |
| 261 | + return new DispatcherServletDelegatingRequestMatcher(ant, mvc); |
334 | 262 | } |
| 263 | + return mvc; |
335 | 264 | } |
336 | 265 |
|
337 | 266 | private static String computeErrorMessage(Collection<? extends ServletRegistration> registrations) { |
|
0 commit comments