package com.zy.common.utils; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.core.common.Cools; import jakarta.servlet.http.HttpServletRequest; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.KeyFactory; import java.security.MessageDigest; import java.security.PublicKey; import java.security.Signature; import java.security.spec.MGF1ParameterSpec; import java.security.spec.PSSParameterSpec; import java.security.spec.X509EncodedKeySpec; import java.util.ArrayList; import java.util.Base64; import java.util.List; import java.util.Locale; public final class PasskeyWebAuthnUtil { private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder(); private PasskeyWebAuthnUtil() { } public static JSONObject parseClientData(String clientDataJsonBase64Url) { if (Cools.isEmpty(clientDataJsonBase64Url)) { throw new IllegalArgumentException("Missing clientDataJSON"); } String json = new String(decodeBase64Url(clientDataJsonBase64Url), StandardCharsets.UTF_8); JSONObject clientData = JSON.parseObject(json); if (clientData == null) { throw new IllegalArgumentException("Invalid clientDataJSON"); } return clientData; } public static void validateClientData(JSONObject clientData, String expectedType, String expectedChallenge, String expectedOrigin) { if (clientData == null) { throw new IllegalArgumentException("Missing clientData"); } if (!Cools.eq(expectedType, clientData.getString("type"))) { throw new IllegalArgumentException("Unexpected WebAuthn type"); } if (!Cools.eq(expectedChallenge, clientData.getString("challenge"))) { throw new IllegalArgumentException("Challenge mismatch"); } if (!Cools.eq(expectedOrigin, clientData.getString("origin"))) { throw new IllegalArgumentException("Origin mismatch"); } } public static AuthenticatorData validateAuthenticatorData(String authenticatorDataBase64Url, String rpId, boolean requireUserVerification) throws GeneralSecurityException { byte[] authenticatorData = decodeBase64Url(authenticatorDataBase64Url); if (authenticatorData.length < 37) { throw new GeneralSecurityException("Invalid authenticator data"); } byte[] expectedRpIdHash = sha256(rpId.getBytes(StandardCharsets.UTF_8)); for (int i = 0; i < expectedRpIdHash.length; i++) { if (authenticatorData[i] != expectedRpIdHash[i]) { throw new GeneralSecurityException("RP ID hash mismatch"); } } int flags = authenticatorData[32] & 0xFF; if ((flags & 0x01) == 0) { throw new GeneralSecurityException("User presence required"); } if (requireUserVerification && (flags & 0x04) == 0) { throw new GeneralSecurityException("User verification required"); } long signCount = ByteBuffer.wrap(authenticatorData, 33, 4).getInt() & 0xFFFFFFFFL; return new AuthenticatorData(authenticatorData, flags, signCount); } public static void verifyAssertionSignature(String publicKeyBase64Url, Integer algorithm, String authenticatorDataBase64Url, String clientDataJsonBase64Url, String signatureBase64Url) throws GeneralSecurityException { PublicKey publicKey = readPublicKey(publicKeyBase64Url, algorithm); Signature verifier = createSignatureVerifier(publicKey, algorithm); verifier.initVerify(publicKey); verifier.update(decodeBase64Url(authenticatorDataBase64Url)); verifier.update(sha256(decodeBase64Url(clientDataJsonBase64Url))); if (!verifier.verify(decodeBase64Url(signatureBase64Url))) { throw new GeneralSecurityException("Invalid passkey signature"); } } public static void ensurePublicKeyMaterial(String publicKeyBase64Url, Integer algorithm) throws GeneralSecurityException { readPublicKey(publicKeyBase64Url, algorithm); } public static byte[] decodeBase64Url(String value) { if (Cools.isEmpty(value)) { throw new IllegalArgumentException("Missing base64Url value"); } return URL_DECODER.decode(String.valueOf(value).trim()); } public static String buildOrigin(HttpServletRequest request) { String scheme = normalizeForwardedValue(request.getHeader("X-Forwarded-Proto")); if (Cools.isEmpty(scheme)) { scheme = request.getScheme(); } String host = resolveHost(request); return scheme.toLowerCase(Locale.ROOT) + "://" + host; } public static String buildRpId(HttpServletRequest request) { String host = resolveHost(request); if (host.startsWith("[")) { int bracket = host.indexOf(']'); return bracket > 0 ? host.substring(1, bracket).toLowerCase(Locale.ROOT) : host.toLowerCase(Locale.ROOT); } int colonIndex = host.indexOf(':'); if (colonIndex >= 0) { host = host.substring(0, colonIndex); } return host.toLowerCase(Locale.ROOT); } public static boolean isSecureOriginAllowed(String origin, String rpId) { if (Cools.isEmpty(origin) || Cools.isEmpty(rpId)) { return false; } String lowerOrigin = origin.toLowerCase(Locale.ROOT); String lowerRpId = rpId.toLowerCase(Locale.ROOT); if ("localhost".equals(lowerRpId) || "127.0.0.1".equals(lowerRpId) || "::1".equals(lowerRpId) || lowerRpId.endsWith(".localhost")) { return true; } return lowerOrigin.startsWith("https://"); } public static byte[] buildUserHandle(Long userId) { return String.valueOf(userId).getBytes(StandardCharsets.UTF_8); } private static PublicKey readPublicKey(String publicKeyBase64Url, Integer algorithm) throws GeneralSecurityException { byte[] encoded = decodeBase64Url(publicKeyBase64Url); X509EncodedKeySpec keySpec = new X509EncodedKeySpec(encoded); List keyFactories = keyFactoriesForAlgorithm(algorithm); GeneralSecurityException failure = null; for (String keyFactoryName : keyFactories) { try { return KeyFactory.getInstance(keyFactoryName).generatePublic(keySpec); } catch (GeneralSecurityException ex) { failure = ex; } } throw failure == null ? new GeneralSecurityException("Unsupported passkey algorithm") : failure; } private static Signature createSignatureVerifier(PublicKey publicKey, Integer algorithm) throws GeneralSecurityException { int value = algorithm == null ? Integer.MIN_VALUE : algorithm; switch (value) { case -7: return Signature.getInstance("SHA256withECDSA"); case -257: return Signature.getInstance("SHA256withRSA"); case -37: Signature pss = Signature.getInstance("RSASSA-PSS"); pss.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1)); return pss; case -8: return Signature.getInstance("Ed25519"); default: if ("EC".equalsIgnoreCase(publicKey.getAlgorithm())) { return Signature.getInstance("SHA256withECDSA"); } if ("RSA".equalsIgnoreCase(publicKey.getAlgorithm())) { return Signature.getInstance("SHA256withRSA"); } if ("Ed25519".equalsIgnoreCase(publicKey.getAlgorithm()) || "EdDSA".equalsIgnoreCase(publicKey.getAlgorithm())) { return Signature.getInstance("Ed25519"); } throw new GeneralSecurityException("Unsupported passkey signature algorithm"); } } private static List keyFactoriesForAlgorithm(Integer algorithm) { List result = new ArrayList<>(); int value = algorithm == null ? Integer.MIN_VALUE : algorithm; switch (value) { case -7: result.add("EC"); break; case -257: case -37: result.add("RSA"); break; case -8: result.add("Ed25519"); result.add("EdDSA"); break; default: result.add("EC"); result.add("RSA"); result.add("Ed25519"); result.add("EdDSA"); break; } return result; } private static String resolveHost(HttpServletRequest request) { String host = normalizeForwardedValue(request.getHeader("X-Forwarded-Host")); if (Cools.isEmpty(host)) { host = request.getServerName(); int port = request.getServerPort(); if (port > 0 && port != 80 && port != 443) { host = host + ":" + port; } } String port = normalizeForwardedValue(request.getHeader("X-Forwarded-Port")); if (!Cools.isEmpty(port) && host.indexOf(':') < 0 && !host.startsWith("[")) { host = host + ":" + port; } return host; } private static String normalizeForwardedValue(String value) { if (Cools.isEmpty(value)) { return null; } String normalized = String.valueOf(value).trim(); int commaIndex = normalized.indexOf(','); if (commaIndex >= 0) { normalized = normalized.substring(0, commaIndex).trim(); } return normalized; } private static byte[] sha256(byte[] data) throws GeneralSecurityException { return MessageDigest.getInstance("SHA-256").digest(data); } public static final class AuthenticatorData { private final byte[] raw; private final int flags; private final long signCount; private AuthenticatorData(byte[] raw, int flags, long signCount) { this.raw = raw; this.flags = flags; this.signCount = signCount; } public byte[] getRaw() { return raw; } public int getFlags() { return flags; } public long getSignCount() { return signCount; } } }