Files
bflows-bandi-be/src/main/java/net/gepafin/tendermanagement/config/SamlConfig.java
2024-10-20 02:37:53 +05:30

227 lines
11 KiB
Java

package net.gepafin.tendermanagement.config;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.PKCS8EncodedKeySpec;
import java.time.Instant;
import java.util.UUID;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemReader;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.common.SAMLVersion;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml.saml2.core.AuthnContextComparisonTypeEnumeration;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.RequestedAuthnContext;
import org.opensaml.saml.saml2.core.impl.AuthnContextClassRefBuilder;
import org.opensaml.saml.saml2.core.impl.RequestedAuthnContextBuilder;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.support.SignatureConstants;
import org.opensaml.xmlsec.signature.support.Signer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.authentication.OpenSaml4AuthenticationRequestResolver;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2AuthenticationRequestResolver;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import jakarta.servlet.http.HttpServletRequest;
import net.gepafin.tendermanagement.entities.SamlResponseEntity;
import net.gepafin.tendermanagement.enums.SamlResponseStatusEnum;
import net.gepafin.tendermanagement.repositories.SamlResponseRepository;
/**
 * Spring Security SAML 2.0 service-provider configuration for the "loginumbria" relying party.
 *
 * <p>Registers the relying party (SP entity id, assertion consumer service URL, SP signing
 * credential, IdP endpoints and IdP verification certificate) and customizes every outgoing
 * AuthnRequest so that its id is persisted as an {@code INITIATED} {@link SamlResponseEntity}
 * tied to the {@code hubId} request attribute.
 *
 * <p>Key material is loaded from the classpath under {@code ${active.profile.folder}/saml/}.
 */
@Configuration
public class SamlConfig {

    private static final Logger logger = LoggerFactory.getLogger(SamlConfig.class);

    /** Public base URL of this service; used to build the SP entity id and ACS location. */
    @Value("${base-url}")
    String baseUrl;

    /** Base URL of the asserting party (IdP); its metadata and SSO endpoints are derived from it. */
    @Value("${spid.ipd.base.url}")
    String ipdBaseUrl;

    /** Classpath folder holding profile-specific resources, including the {@code saml/} key files. */
    @Value("${active.profile.folder}")
    String activeProfileFolder;

    @Autowired
    private SamlResponseRepository samlResponseRepository;

    /**
     * Builds the single "loginumbria" relying-party registration.
     *
     * <p>The SP signing credential and the IdP verification certificate are loaded up front. If either
     * cannot be read, startup fails. Previously the error was printed and the registration was built
     * without credentials, which only surfaced later at login time.
     *
     * @return an in-memory repository containing the "loginumbria" registration
     * @throws IllegalStateException if the SP key/certificate or the IdP certificate cannot be loaded
     */
    @Bean
    public RelyingPartyRegistrationRepository relyingPartyRegistrationRepository() {
        String entityId = baseUrl + "/v1/saml/gw/metadata";
        String acsUrl = baseUrl + "/login/saml2/sso/loginumbria";

        Saml2X509Credential signingCredential = loadSigningCredential();
        // IdP public certificate used to verify the signature on SAML responses.
        Saml2X509Credential verificationCredential = loadIdpVerificationCredential();

        RelyingPartyRegistration registration = RelyingPartyRegistration.withRegistrationId("loginumbria")
                .entityId(entityId)
                .signingX509Credentials(credentials -> credentials.add(signingCredential))
                .assertionConsumerServiceLocation(acsUrl)
                .assertingPartyDetails(details -> details.entityId(ipdBaseUrl + "/gw/metadata")
                        .singleSignOnServiceLocation(ipdBaseUrl + "/gw/SSOProxy/SAML2")
                        .singleSignOnServiceBinding(Saml2MessageBinding.POST)
                        .wantAuthnRequestsSigned(true)
                        .verificationX509Credentials(credentials -> credentials.add(verificationCredential)))
                .build();
        return new InMemoryRelyingPartyRegistrationRepository(registration);
    }

    /** Loads the SP private key + certificate pair, failing fast with the underlying cause. */
    private Saml2X509Credential loadSigningCredential() {
        try {
            return Saml2X509Credential.signing(readPrivateKey(), readCertificate());
        } catch (Exception e) {
            throw new IllegalStateException(
                    "Unable to load SAML SP signing credential from " + activeProfileFolder + "/saml", e);
        }
    }

    /** Loads the IdP verification certificate, failing fast with the underlying cause. */
    private Saml2X509Credential loadIdpVerificationCredential() {
        try {
            return Saml2X509Credential.verification(readIdpCertificate());
        } catch (Exception e) {
            throw new IllegalStateException(
                    "Unable to load SAML IdP verification certificate from " + activeProfileFolder + "/saml", e);
        }
    }

    /**
     * Builds a minimal AuthnRequest and signs it with the given key pair. The request has a random
     * {@code _}-prefixed id, SAML version 2.0 and the current issue instant. Signing uses exclusive
     * canonicalization (omit comments) and RSA-SHA1.
     *
     * <p>NOTE(review): this method is not referenced by the beans in this class and sets no
     * Issuer/Destination. It assumes OpenSAML has already been initialized so that the builder and
     * marshaller factories are populated. RSA-SHA1 is deprecated; confirm with the IdP whether
     * RSA-SHA256 is accepted before relying on this method.
     *
     * @param privateKey  the SP signing key
     * @param certificate the certificate matching {@code privateKey}
     * @return the marshalled and signed AuthnRequest
     * @throws Exception if marshalling or signing fails
     */
    public AuthnRequest createSignedAuthnRequest(PrivateKey privateKey, X509Certificate certificate) throws Exception {
        AuthnRequest authnRequest = (AuthnRequest) XMLObjectProviderRegistrySupport.getBuilderFactory()
                .getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME)
                .buildObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
        authnRequest.setID("_" + UUID.randomUUID());
        authnRequest.setVersion(SAMLVersion.VERSION_20);
        authnRequest.setIssueInstant(Instant.now());

        BasicX509Credential signingCredential = new BasicX509Credential(certificate, privateKey);
        Signature signature = (Signature) XMLObjectProviderRegistrySupport.getBuilderFactory()
                .getBuilder(Signature.DEFAULT_ELEMENT_NAME)
                .buildObject(Signature.DEFAULT_ELEMENT_NAME);
        signature.setCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
        signature.setSigningCredential(signingCredential);
        signature.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA1);
        authnRequest.setSignature(signature);

        // The DOM must exist before Signer can compute the signature over it.
        XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(authnRequest).marshall(authnRequest);
        Signer.signObject(signature);
        return authnRequest;
    }

    /**
     * Builds the AuthnRequest resolver used by Spring Security. It customizes each request before it
     * is sent:
     * <ul>
     *   <li>assigns a fresh {@code _}-prefixed random id;</li>
     *   <li>forces SAML 2.0 and the HTTP-POST protocol binding;</li>
     *   <li>adds the requested authentication context;</li>
     *   <li>saves an {@code INITIATED} {@link SamlResponseEntity} that links that id (as
     *       {@code inResponseTo}) to the {@code hubId} request attribute.</li>
     * </ul>
     *
     * @param registrations the relying-party registrations to resolve against
     * @return the customized resolver
     */
    @Bean
    public Saml2AuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistrationRepository registrations) {
        RelyingPartyRegistrationResolver registrationResolver = new DefaultRelyingPartyRegistrationResolver(registrations);
        OpenSaml4AuthenticationRequestResolver authenticationRequestResolver =
                new OpenSaml4AuthenticationRequestResolver(registrationResolver);
        authenticationRequestResolver.setAuthnRequestCustomizer(context -> {
            HttpServletRequest request = context.getRequest();
            // May be null if nothing upstream set the attribute; it is persisted as-is.
            String hubUuid = (String) request.getAttribute("hubId");
            logger.info("Hub id {}", hubUuid);

            // The id of this request becomes the InResponseTo value expected on the IdP's response.
            String inResponseTo = "_" + UUID.randomUUID();
            AuthnRequest authnRequest = context.getAuthnRequest();
            authnRequest.setID(inResponseTo);
            authnRequest.setVersion(SAMLVersion.VERSION_20);
            authnRequest.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI);
            authnRequest.setRequestedAuthnContext(buildRequestedAuthnContext());

            SamlResponseEntity samlResponse = new SamlResponseEntity();
            samlResponse.setHubUuid(hubUuid);
            samlResponse.setInResponseTo(inResponseTo);
            samlResponse.setStatus(SamlResponseStatusEnum.INITIATED.getValue());
            samlResponseRepository.save(samlResponse);

            // Serializing the request is not free; skip it when INFO logging is disabled.
            if (logger.isInfoEnabled()) {
                logger.info("SAML AuthnRequest after setting context: {}",
                        SamlRequestLogger.convertSAMLObjectToString(authnRequest));
            }
        });
        return authenticationRequestResolver;
    }

    /**
     * Builds a {@code RequestedAuthnContext} with EXACT comparison and a single class reference.
     *
     * <p>NOTE(review): the original comment described this as the SPID Level 2 context, but the URI
     * used is the OASIS {@code SecureRemotePassword} class. Confirm with the IdP which class it expects.
     *
     * @return the requested authentication context to attach to an AuthnRequest
     */
    private RequestedAuthnContext buildRequestedAuthnContext() {
        AuthnContextClassRefBuilder authnContextClassRefBuilder = new AuthnContextClassRefBuilder();
        AuthnContextClassRef authnContextClassRef = authnContextClassRefBuilder.buildObject(
                SAMLConstants.SAML20_NS, AuthnContextClassRef.DEFAULT_ELEMENT_LOCAL_NAME, SAMLConstants.SAML20_PREFIX);
        authnContextClassRef.setURI("urn:oasis:names:tc:SAML:2.0:ac:classes:SecureRemotePassword");

        RequestedAuthnContextBuilder requestedAuthnContextBuilder = new RequestedAuthnContextBuilder();
        RequestedAuthnContext requestedAuthnContext = requestedAuthnContextBuilder.buildObject();
        requestedAuthnContext.setComparison(AuthnContextComparisonTypeEnumeration.EXACT);
        requestedAuthnContext.getAuthnContextClassRefs().add(authnContextClassRef);
        return requestedAuthnContext;
    }

    /**
     * Reads the SP RSA private key from {@code <activeProfileFolder>/saml/private-key.pem} on the classpath.
     *
     * <p>The PEM payload is decoded as PKCS#8 ({@code -----BEGIN PRIVATE KEY-----}). A PKCS#1
     * {@code RSA PRIVATE KEY} file will be rejected by the key factory.
     *
     * @return the RSA private key
     * @throws java.io.FileNotFoundException if the resource is missing
     * @throws IOException if the file contains no PEM object or cannot be read
     * @throws Exception if the key bytes cannot be parsed as a PKCS#8 RSA key
     */
    public PrivateKey readPrivateKey() throws Exception {
        String path = activeProfileFolder + "/saml/private-key.pem";
        try (PemReader pemReader = new PemReader(new InputStreamReader(readKey(path), StandardCharsets.UTF_8))) {
            PemObject pemObject = pemReader.readPemObject();
            if (pemObject == null) {
                throw new IOException("No PEM object found in " + path);
            }
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(pemObject.getContent());
            return KeyFactory.getInstance("RSA").generatePrivate(keySpec);
        }
    }

    /**
     * Reads the SP certificate from {@code <activeProfileFolder>/saml/public-cert.pem} on the classpath.
     *
     * @return the SP X.509 certificate
     * @throws Exception if the resource is missing or is not a valid X.509 certificate
     */
    public X509Certificate readCertificate() throws Exception {
        return readX509Certificate(activeProfileFolder + "/saml/public-cert.pem");
    }

    /**
     * Reads the IdP certificate from {@code <activeProfileFolder>/saml/idp-certificate.pem} on the classpath.
     *
     * @return the IdP X.509 certificate used to verify SAML responses
     * @throws Exception if the resource is missing or is not a valid X.509 certificate
     */
    public X509Certificate readIdpCertificate() throws Exception {
        return readX509Certificate(activeProfileFolder + "/saml/idp-certificate.pem");
    }

    /** Parses a single X.509 certificate from the given classpath resource, closing the stream afterwards. */
    private X509Certificate readX509Certificate(String path) throws Exception {
        try (InputStream inStream = readKey(path)) {
            CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
            return (X509Certificate) certFactory.generateCertificate(inStream);
        }
    }

    /**
     * Opens a classpath resource. The caller is responsible for closing the returned stream.
     *
     * @param path classpath-relative resource path
     * @return an open stream over the resource
     * @throws FileNotFoundException if no resource exists at {@code path}
     */
    public InputStream readKey(String path) throws IOException {
        ClassLoader classLoader = getClass().getClassLoader();
        InputStream inputStream = classLoader.getResourceAsStream(path);
        if (inputStream == null) {
            throw new FileNotFoundException("file not found : "+path);
        }
        return inputStream;
    }
}