#
Junjie
昨天 be1cd9e5b30097ca427a9c2b7b054b28854e410a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package com.zy.common.utils;
 
import com.core.common.Cools;
 
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.io.ByteArrayOutputStream;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Locale;
 
public final class MfaTotpUtil {
 
    private static final char[] BASE32_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567".toCharArray();
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();
    private static final int SECRET_SIZE = 20;
    private static final int OTP_DIGITS = 6;
    private static final int OTP_PERIOD_SECONDS = 30;
 
    private MfaTotpUtil() {
    }
 
    public static String generateSecret() {
        byte[] buffer = new byte[SECRET_SIZE];
        SECURE_RANDOM.nextBytes(buffer);
        return encodeBase32(buffer);
    }
 
    public static boolean verifyCode(String secret, String code, int window) {
        if (Cools.isEmpty(secret, code)) {
            return false;
        }
        String normalizedCode = String.valueOf(code).replaceAll("\\s+", "");
        if (!normalizedCode.matches("\\d{" + OTP_DIGITS + "}")) {
            return false;
        }
        try {
            long currentStep = System.currentTimeMillis() / 1000L / OTP_PERIOD_SECONDS;
            for (int offset = -window; offset <= window; offset++) {
                if (normalizedCode.equals(generateCode(secret, currentStep + offset))) {
                    return true;
                }
            }
        } catch (Exception ignored) {
        }
        return false;
    }
 
    public static String buildOtpAuthUri(String issuer, String account, String secret) {
        String safeIssuer = Cools.isEmpty(issuer) ? "WCS" : issuer.trim();
        String safeAccount = Cools.isEmpty(account) ? "user" : account.trim();
        String label = urlEncode(safeIssuer + ":" + safeAccount);
        return "otpauth://totp/" + label
                + "?secret=" + secret
                + "&issuer=" + urlEncode(safeIssuer)
                + "&algorithm=SHA1&digits=" + OTP_DIGITS
                + "&period=" + OTP_PERIOD_SECONDS;
    }
 
    public static String maskSecret(String secret) {
        if (Cools.isEmpty(secret)) {
            return "";
        }
        String value = String.valueOf(secret).trim();
        if (value.length() <= 8) {
            return value;
        }
        return value.substring(0, 4) + "****" + value.substring(value.length() - 4);
    }
 
    private static String generateCode(String secret, long step) {
        try {
            byte[] key = decodeBase32(secret);
            byte[] data = ByteBuffer.allocate(8).putLong(step).array();
            Mac mac = Mac.getInstance("HmacSHA1");
            mac.init(new SecretKeySpec(key, "HmacSHA1"));
            byte[] hash = mac.doFinal(data);
            int offset = hash[hash.length - 1] & 0x0F;
            int binary = ((hash[offset] & 0x7F) << 24)
                    | ((hash[offset + 1] & 0xFF) << 16)
                    | ((hash[offset + 2] & 0xFF) << 8)
                    | (hash[offset + 3] & 0xFF);
            int otp = binary % (int) Math.pow(10, OTP_DIGITS);
            return String.format(Locale.ROOT, "%0" + OTP_DIGITS + "d", otp);
        } catch (Exception e) {
            throw new IllegalStateException("generate totp code failed", e);
        }
    }
 
    private static String encodeBase32(byte[] data) {
        StringBuilder builder = new StringBuilder((data.length * 8 + 4) / 5);
        int buffer = 0;
        int bitsLeft = 0;
        for (byte datum : data) {
            buffer = (buffer << 8) | (datum & 0xFF);
            bitsLeft += 8;
            while (bitsLeft >= 5) {
                builder.append(BASE32_ALPHABET[(buffer >> (bitsLeft - 5)) & 0x1F]);
                bitsLeft -= 5;
            }
        }
        if (bitsLeft > 0) {
            builder.append(BASE32_ALPHABET[(buffer << (5 - bitsLeft)) & 0x1F]);
        }
        return builder.toString();
    }
 
    private static byte[] decodeBase32(String value) {
        String normalized = String.valueOf(value)
                .trim()
                .replace("=", "")
                .replace(" ", "")
                .replace("-", "")
                .toUpperCase(Locale.ROOT);
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        int buffer = 0;
        int bitsLeft = 0;
        for (int i = 0; i < normalized.length(); i++) {
            char current = normalized.charAt(i);
            int index = indexOfBase32(current);
            if (index < 0) {
                throw new IllegalArgumentException("invalid base32 secret");
            }
            buffer = (buffer << 5) | index;
            bitsLeft += 5;
            if (bitsLeft >= 8) {
                output.write((buffer >> (bitsLeft - 8)) & 0xFF);
                bitsLeft -= 8;
            }
        }
        return output.toByteArray();
    }
 
    private static int indexOfBase32(char value) {
        for (int i = 0; i < BASE32_ALPHABET.length; i++) {
            if (BASE32_ALPHABET[i] == value) {
                return i;
            }
        }
        return -1;
    }
 
    private static String urlEncode(String value) {
        return URLEncoder.encode(value, StandardCharsets.UTF_8).replace("+", "%20");
    }
}