diff --git a/src/main/src/main/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilter.java b/src/main/src/main/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilter.java index cbb6083bc7b..f6ec84e6415 100644 --- a/src/main/src/main/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilter.java +++ b/src/main/src/main/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilter.java @@ -7,12 +7,19 @@ package org.geoserver.security.filter; import java.io.IOException; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import org.geoserver.security.config.RequestHeaderAuthenticationFilterConfig; import org.geoserver.security.config.SecurityNamedServiceConfig; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; /** - * J2EE Authentication Filter + * Request header Authentication Filter * * @author mcr */ @@ -38,6 +45,22 @@ public void initializeFromConfig(SecurityNamedServiceConfig config) throws IOExc setPrincipalHeaderAttribute(authConfig.getPrincipalHeaderAttribute()); } + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + String principalName = getPreAuthenticatedPrincipalName((HttpServletRequest) request); + Authentication preAuth = SecurityContextHolder.getContext().getAuthentication(); + // If a pre-auth token exists but the request has no principal name anymore + // or differs from the one being sent in the headers, clear the + // security context, or else the user will remain authenticated. + if (preAuth instanceof PreAuthenticatedAuthenticationToken + && ((null == principalName) + || (!principalName.equals(preAuth.getPrincipal().toString())))) { + SecurityContextHolder.clearContext(); + } + super.doFilter(request, response, chain); + } + @Override protected String getPreAuthenticatedPrincipalName(HttpServletRequest request) { return request.getHeader(getPrincipalHeaderAttribute()); diff --git a/src/main/src/test/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilterTest.java b/src/main/src/test/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilterTest.java new file mode 100644 index 00000000000..1acdde7ae07 --- /dev/null +++ b/src/main/src/test/java/org/geoserver/security/filter/GeoServerRequestHeaderAuthenticationFilterTest.java @@ -0,0 +1,71 @@ +/* (c) 2023 Open Source Geospatial Foundation - all rights reserved + * This code is licensed under the GPL 2.0 license, available at the root + * application directory. + */ +package org.geoserver.security.filter; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import java.io.File; +import javax.servlet.http.HttpServletResponse; +import org.geoserver.config.GeoServerDataDirectory; +import org.geoserver.security.GeoServerSecurityManager; +import org.geoserver.security.config.PreAuthenticatedUserNameFilterConfig; +import org.junit.Test; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.context.SecurityContextImpl; +import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; + +public class GeoServerRequestHeaderAuthenticationFilterTest { + + @Test + public void testAuthenticationViaPreAuthChanging() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + HttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + SecurityContext sc = new SecurityContextImpl(); + sc.setAuthentication(new PreAuthenticatedAuthenticationToken("testadmin", null)); + SecurityContextHolder.setContext(sc); + GeoServerRequestHeaderAuthenticationFilter toTest = + new GeoServerRequestHeaderAuthenticationFilter(); + toTest.setPrincipalHeaderAttribute("sec-username"); + request.addHeader("sec-username", "testuser"); + toTest.setSecurityManager( + new GeoServerSecurityManager(new GeoServerDataDirectory(new File("/tmp")))); + toTest.setRoleSource( + PreAuthenticatedUserNameFilterConfig.PreAuthenticatedUserNameRoleSource.Header); + + toTest.doFilter(request, response, filterChain); + + assertEquals( + "testuser", + SecurityContextHolder.getContext().getAuthentication().getPrincipal().toString()); + } + + @Test + public void testAuthenticationViaPreAuthNoHeader() throws Exception { + MockHttpServletRequest request = new MockHttpServletRequest(); + HttpServletResponse response = new MockHttpServletResponse(); + MockFilterChain filterChain = new MockFilterChain(); + SecurityContext sc = new SecurityContextImpl(); + sc.setAuthentication(new PreAuthenticatedAuthenticationToken("testadmin", null)); + SecurityContextHolder.setContext(sc); + GeoServerRequestHeaderAuthenticationFilter toTest = + new GeoServerRequestHeaderAuthenticationFilter(); + toTest.setPrincipalHeaderAttribute("sec-username"); + toTest.setSecurityManager( + new GeoServerSecurityManager(new GeoServerDataDirectory(new File("/tmp")))); + toTest.setRoleSource( + PreAuthenticatedUserNameFilterConfig.PreAuthenticatedUserNameRoleSource.Header); + + toTest.doFilter(request, response, filterChain); + + // The security context should have been cleared + assertNull(SecurityContextHolder.getContext().getAuthentication()); + } +}