package com.zy.common.utils;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONObject;
|
import com.core.common.Cools;
|
|
import jakarta.servlet.http.HttpServletRequest;
|
import java.nio.ByteBuffer;
|
import java.nio.charset.StandardCharsets;
|
import java.security.GeneralSecurityException;
|
import java.security.KeyFactory;
|
import java.security.MessageDigest;
|
import java.security.PublicKey;
|
import java.security.Signature;
|
import java.security.spec.MGF1ParameterSpec;
|
import java.security.spec.PSSParameterSpec;
|
import java.security.spec.X509EncodedKeySpec;
|
import java.util.ArrayList;
|
import java.util.Base64;
|
import java.util.List;
|
import java.util.Locale;
|
|
public final class PasskeyWebAuthnUtil {
|
|
private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder();
|
|
private PasskeyWebAuthnUtil() {
|
}
|
|
public static JSONObject parseClientData(String clientDataJsonBase64Url) {
|
if (Cools.isEmpty(clientDataJsonBase64Url)) {
|
throw new IllegalArgumentException("Missing clientDataJSON");
|
}
|
String json = new String(decodeBase64Url(clientDataJsonBase64Url), StandardCharsets.UTF_8);
|
JSONObject clientData = JSON.parseObject(json);
|
if (clientData == null) {
|
throw new IllegalArgumentException("Invalid clientDataJSON");
|
}
|
return clientData;
|
}
|
|
public static void validateClientData(JSONObject clientData, String expectedType, String expectedChallenge, String expectedOrigin) {
|
if (clientData == null) {
|
throw new IllegalArgumentException("Missing clientData");
|
}
|
if (!Cools.eq(expectedType, clientData.getString("type"))) {
|
throw new IllegalArgumentException("Unexpected WebAuthn type");
|
}
|
if (!Cools.eq(expectedChallenge, clientData.getString("challenge"))) {
|
throw new IllegalArgumentException("Challenge mismatch");
|
}
|
if (!Cools.eq(expectedOrigin, clientData.getString("origin"))) {
|
throw new IllegalArgumentException("Origin mismatch");
|
}
|
}
|
|
public static AuthenticatorData validateAuthenticatorData(String authenticatorDataBase64Url, String rpId, boolean requireUserVerification) throws GeneralSecurityException {
|
byte[] authenticatorData = decodeBase64Url(authenticatorDataBase64Url);
|
if (authenticatorData.length < 37) {
|
throw new GeneralSecurityException("Invalid authenticator data");
|
}
|
byte[] expectedRpIdHash = sha256(rpId.getBytes(StandardCharsets.UTF_8));
|
for (int i = 0; i < expectedRpIdHash.length; i++) {
|
if (authenticatorData[i] != expectedRpIdHash[i]) {
|
throw new GeneralSecurityException("RP ID hash mismatch");
|
}
|
}
|
int flags = authenticatorData[32] & 0xFF;
|
if ((flags & 0x01) == 0) {
|
throw new GeneralSecurityException("User presence required");
|
}
|
if (requireUserVerification && (flags & 0x04) == 0) {
|
throw new GeneralSecurityException("User verification required");
|
}
|
long signCount = ByteBuffer.wrap(authenticatorData, 33, 4).getInt() & 0xFFFFFFFFL;
|
return new AuthenticatorData(authenticatorData, flags, signCount);
|
}
|
|
public static void verifyAssertionSignature(String publicKeyBase64Url,
|
Integer algorithm,
|
String authenticatorDataBase64Url,
|
String clientDataJsonBase64Url,
|
String signatureBase64Url) throws GeneralSecurityException {
|
PublicKey publicKey = readPublicKey(publicKeyBase64Url, algorithm);
|
Signature verifier = createSignatureVerifier(publicKey, algorithm);
|
verifier.initVerify(publicKey);
|
verifier.update(decodeBase64Url(authenticatorDataBase64Url));
|
verifier.update(sha256(decodeBase64Url(clientDataJsonBase64Url)));
|
if (!verifier.verify(decodeBase64Url(signatureBase64Url))) {
|
throw new GeneralSecurityException("Invalid passkey signature");
|
}
|
}
|
|
public static void ensurePublicKeyMaterial(String publicKeyBase64Url, Integer algorithm) throws GeneralSecurityException {
|
readPublicKey(publicKeyBase64Url, algorithm);
|
}
|
|
public static byte[] decodeBase64Url(String value) {
|
if (Cools.isEmpty(value)) {
|
throw new IllegalArgumentException("Missing base64Url value");
|
}
|
return URL_DECODER.decode(String.valueOf(value).trim());
|
}
|
|
public static String buildOrigin(HttpServletRequest request) {
|
String scheme = normalizeForwardedValue(request.getHeader("X-Forwarded-Proto"));
|
if (Cools.isEmpty(scheme)) {
|
scheme = request.getScheme();
|
}
|
String host = resolveHost(request);
|
return scheme.toLowerCase(Locale.ROOT) + "://" + host;
|
}
|
|
public static String buildRpId(HttpServletRequest request) {
|
String host = resolveHost(request);
|
if (host.startsWith("[")) {
|
int bracket = host.indexOf(']');
|
return bracket > 0 ? host.substring(1, bracket).toLowerCase(Locale.ROOT) : host.toLowerCase(Locale.ROOT);
|
}
|
int colonIndex = host.indexOf(':');
|
if (colonIndex >= 0) {
|
host = host.substring(0, colonIndex);
|
}
|
return host.toLowerCase(Locale.ROOT);
|
}
|
|
public static boolean isSecureOriginAllowed(String origin, String rpId) {
|
if (Cools.isEmpty(origin) || Cools.isEmpty(rpId)) {
|
return false;
|
}
|
String lowerOrigin = origin.toLowerCase(Locale.ROOT);
|
String lowerRpId = rpId.toLowerCase(Locale.ROOT);
|
if ("localhost".equals(lowerRpId)
|
|| "127.0.0.1".equals(lowerRpId)
|
|| "::1".equals(lowerRpId)
|
|| lowerRpId.endsWith(".localhost")) {
|
return true;
|
}
|
return lowerOrigin.startsWith("https://");
|
}
|
|
public static byte[] buildUserHandle(Long userId) {
|
return String.valueOf(userId).getBytes(StandardCharsets.UTF_8);
|
}
|
|
private static PublicKey readPublicKey(String publicKeyBase64Url, Integer algorithm) throws GeneralSecurityException {
|
byte[] encoded = decodeBase64Url(publicKeyBase64Url);
|
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(encoded);
|
List<String> keyFactories = keyFactoriesForAlgorithm(algorithm);
|
GeneralSecurityException failure = null;
|
for (String keyFactoryName : keyFactories) {
|
try {
|
return KeyFactory.getInstance(keyFactoryName).generatePublic(keySpec);
|
} catch (GeneralSecurityException ex) {
|
failure = ex;
|
}
|
}
|
throw failure == null ? new GeneralSecurityException("Unsupported passkey algorithm") : failure;
|
}
|
|
private static Signature createSignatureVerifier(PublicKey publicKey, Integer algorithm) throws GeneralSecurityException {
|
int value = algorithm == null ? Integer.MIN_VALUE : algorithm;
|
switch (value) {
|
case -7:
|
return Signature.getInstance("SHA256withECDSA");
|
case -257:
|
return Signature.getInstance("SHA256withRSA");
|
case -37:
|
Signature pss = Signature.getInstance("RSASSA-PSS");
|
pss.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1));
|
return pss;
|
case -8:
|
return Signature.getInstance("Ed25519");
|
default:
|
if ("EC".equalsIgnoreCase(publicKey.getAlgorithm())) {
|
return Signature.getInstance("SHA256withECDSA");
|
}
|
if ("RSA".equalsIgnoreCase(publicKey.getAlgorithm())) {
|
return Signature.getInstance("SHA256withRSA");
|
}
|
if ("Ed25519".equalsIgnoreCase(publicKey.getAlgorithm()) || "EdDSA".equalsIgnoreCase(publicKey.getAlgorithm())) {
|
return Signature.getInstance("Ed25519");
|
}
|
throw new GeneralSecurityException("Unsupported passkey signature algorithm");
|
}
|
}
|
|
private static List<String> keyFactoriesForAlgorithm(Integer algorithm) {
|
List<String> result = new ArrayList<>();
|
int value = algorithm == null ? Integer.MIN_VALUE : algorithm;
|
switch (value) {
|
case -7:
|
result.add("EC");
|
break;
|
case -257:
|
case -37:
|
result.add("RSA");
|
break;
|
case -8:
|
result.add("Ed25519");
|
result.add("EdDSA");
|
break;
|
default:
|
result.add("EC");
|
result.add("RSA");
|
result.add("Ed25519");
|
result.add("EdDSA");
|
break;
|
}
|
return result;
|
}
|
|
private static String resolveHost(HttpServletRequest request) {
|
String host = normalizeForwardedValue(request.getHeader("X-Forwarded-Host"));
|
if (Cools.isEmpty(host)) {
|
host = request.getServerName();
|
int port = request.getServerPort();
|
if (port > 0 && port != 80 && port != 443) {
|
host = host + ":" + port;
|
}
|
}
|
String port = normalizeForwardedValue(request.getHeader("X-Forwarded-Port"));
|
if (!Cools.isEmpty(port) && host.indexOf(':') < 0 && !host.startsWith("[")) {
|
host = host + ":" + port;
|
}
|
return host;
|
}
|
|
private static String normalizeForwardedValue(String value) {
|
if (Cools.isEmpty(value)) {
|
return null;
|
}
|
String normalized = String.valueOf(value).trim();
|
int commaIndex = normalized.indexOf(',');
|
if (commaIndex >= 0) {
|
normalized = normalized.substring(0, commaIndex).trim();
|
}
|
return normalized;
|
}
|
|
private static byte[] sha256(byte[] data) throws GeneralSecurityException {
|
return MessageDigest.getInstance("SHA-256").digest(data);
|
}
|
|
public static final class AuthenticatorData {
|
private final byte[] raw;
|
private final int flags;
|
private final long signCount;
|
|
private AuthenticatorData(byte[] raw, int flags, long signCount) {
|
this.raw = raw;
|
this.flags = flags;
|
this.signCount = signCount;
|
}
|
|
public byte[] getRaw() {
|
return raw;
|
}
|
|
public int getFlags() {
|
return flags;
|
}
|
|
public long getSignCount() {
|
return signCount;
|
}
|
}
|
}
|