Implement support for kyber pre keys

This commit is contained in:
AsamK 2023-06-17 21:18:24 +02:00
parent 4e5c859aab
commit 306e38c9ee
11 changed files with 504 additions and 45 deletions

View file

@ -46,7 +46,7 @@
},
{
"name":"org.asamk.signal.manager.storage.protocol.SignalProtocolStore",
"methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }]
"methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadKyberPreKey","parameterTypes":["int"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"markKyberPreKeyUsed","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }]
},
{
"name":"org.asamk.signal.manager.storage.senderKeys.SenderKeyStore",
@ -133,6 +133,10 @@
"name":"org.signal.libsignal.protocol.state.IdentityKeyStore$Direction",
"fields":[{"name":"RECEIVING"}, {"name":"SENDING"}]
},
{
"name":"org.signal.libsignal.protocol.state.KyberPreKeyRecord",
"fields":[{"name":"unsafeHandle"}]
},
{
"name":"org.signal.libsignal.protocol.state.KyberPreKeyStore"
},

View file

@ -2179,16 +2179,25 @@
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity",
"allDeclaredFields":true,
"queryAllDeclaredMethods":true,
"queryAllDeclaredConstructors":true
"queryAllDeclaredConstructors":true,
"methods":[{"name":"<init>","parameterTypes":[] }, {"name":"getKeyId","parameterTypes":[] }, {"name":"getPublicKey","parameterTypes":[] }, {"name":"getSignature","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArrayDeserializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArraySerializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeyDeserializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeySerializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.MismatchedDevices",
"allDeclaredFields":true,
@ -2250,7 +2259,8 @@
"name":"org.whispersystems.signalservice.internal.push.PreKeyState",
"allDeclaredFields":true,
"allDeclaredMethods":true,
"allDeclaredConstructors":true
"allDeclaredConstructors":true,
"methods":[{"name":"getIdentityKey","parameterTypes":[] }, {"name":"getPreKeys","parameterTypes":[] }, {"name":"getSignedPreKey","parameterTypes":[] }]
},
{
"name":"org.whispersystems.signalservice.internal.push.PreKeyStatus",

View file

@ -1,6 +1,7 @@
package org.asamk.signal.manager.config;
import org.asamk.signal.manager.api.ServiceEnvironment;
import org.signal.libsignal.protocol.util.Medium;
import org.whispersystems.signalservice.api.account.AccountAttributes;
import org.whispersystems.signalservice.api.push.TrustStore;
@ -15,8 +16,9 @@ import okhttp3.Interceptor;
public class ServiceConfig {
public final static int PREKEY_MINIMUM_COUNT = 20;
public final static int PREKEY_MINIMUM_COUNT = 10;
public final static int PREKEY_BATCH_SIZE = 100;
public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE;
public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024;
public final static long MAX_ENVELOPE_SIZE = 0;
public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024;

View file

@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies;
import org.asamk.signal.manager.storage.SignalAccount;
import org.asamk.signal.manager.util.KeyUtils;
import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.slf4j.Logger;
@ -39,7 +40,9 @@ public class PreKeyHelper {
if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
refreshPreKeys(serviceIdType);
}
// TODO kyber pre keys
if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
refreshKyberPreKeys(serviceIdType);
}
}
public void refreshPreKeys() throws IOException {
@ -47,7 +50,7 @@ public class PreKeyHelper {
refreshPreKeys(ServiceIdType.PNI);
}
public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return;
@ -97,4 +100,61 @@ public class PreKeyHelper {
return record;
}
private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return;
}
final var accountId = account.getAccountId(serviceIdType);
if (accountId == null) {
return;
}
try {
refreshKyberPreKeys(serviceIdType, identityKeyPair);
} catch (Exception e) {
logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
account.resetKyberPreKeyOffsets(serviceIdType);
refreshKyberPreKeys(serviceIdType, identityKeyPair);
}
}
private void refreshKyberPreKeys(
final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) throws IOException {
final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair);
final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
final var preKeyUpload = new PreKeyUpload(serviceIdType,
identityKeyPair.getPublicKey(),
null,
null,
lastResortPreKeyRecord,
oneTimePreKeys);
dependencies.getAccountManager().setPreKeys(preKeyUpload);
}
private List<KyberPreKeyRecord> generateKyberPreKeys(
ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) {
final var offset = account.getKyberPreKeyIdOffset(serviceIdType);
var records = KeyUtils.generateKyberPreKeyRecords(offset,
ServiceConfig.PREKEY_BATCH_SIZE,
identityKeyPair.getPrivateKey());
account.addKyberPreKeys(serviceIdType, records);
return records;
}
private KyberPreKeyRecord generateLastResortKyberPreKey(
ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
) {
final var signedPreKeyId = account.getKyberPreKeyIdOffset(serviceIdType);
var record = KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
account.addLastResortKyberPreKey(serviceIdType, record);
return record;
}
}

View file

@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariDataSource;
import org.asamk.signal.manager.storage.groups.GroupStore;
import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.recipients.RecipientStore;
@ -23,7 +24,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 13;
private static final long DATABASE_VERSION = 14;
private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource);
@ -40,6 +41,7 @@ public class AccountDatabase extends Database {
StickerStore.createSql(connection);
PreKeyStore.createSql(connection);
SignedPreKeyStore.createSql(connection);
KyberPreKeyStore.createSql(connection);
GroupStore.createSql(connection);
SessionStore.createSql(connection);
IdentityKeyStore.createSql(connection);
@ -328,5 +330,23 @@ public class AccountDatabase extends Database {
}
}
}
if (oldVersion < 14) {
logger.debug("Updating database: Creating kyber_pre_key table");
{
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE kyber_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
serialized BLOB NOT NULL,
is_last_resort INTEGER NOT NULL,
UNIQUE(account_id_type, key_id)
) STRICT;
""");
}
}
}
}
}

View file

@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
import org.asamk.signal.manager.storage.identities.LegacyIdentityKeyStore;
import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore;
import org.asamk.signal.manager.storage.messageCache.MessageCache;
import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
@ -51,11 +52,11 @@ import org.asamk.signal.manager.util.KeyUtils;
import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.util.KeyHelper;
import org.signal.libsignal.protocol.util.Medium;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.slf4j.Logger;
@ -98,6 +99,7 @@ import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
import static org.asamk.signal.manager.config.ServiceConfig.getCapabilities;
public class SignalAccount implements Closeable {
@ -105,7 +107,7 @@ public class SignalAccount implements Closeable {
private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class);
private static final int MINIMUM_STORAGE_VERSION = 1;
private static final int CURRENT_STORAGE_VERSION = 6;
private static final int CURRENT_STORAGE_VERSION = 7;
private final Object LOCK = new Object();
@ -138,6 +140,10 @@ public class SignalAccount implements Closeable {
private int aciNextSignedPreKeyId = 1;
private int pniPreKeyIdOffset = 1;
private int pniNextSignedPreKeyId = 1;
private int aciKyberPreKeyIdOffset = 1;
private int aciActiveLastResortKyberPreKeyId = -1;
private int pniKyberPreKeyIdOffset = 1;
private int pniActiveLastResortKyberPreKeyId = -1;
private IdentityKeyPair aciIdentityKeyPair;
private IdentityKeyPair pniIdentityKeyPair;
private int localRegistrationId;
@ -151,8 +157,10 @@ public class SignalAccount implements Closeable {
private SignalProtocolStore pniSignalProtocolStore;
private PreKeyStore aciPreKeyStore;
private SignedPreKeyStore aciSignedPreKeyStore;
private KyberPreKeyStore aciKyberPreKeyStore;
private PreKeyStore pniPreKeyStore;
private SignedPreKeyStore pniSignedPreKeyStore;
private KyberPreKeyStore pniKyberPreKeyStore;
private SessionStore aciSessionStore;
private SessionStore pniSessionStore;
private IdentityKeyStore identityKeyStore;
@ -302,10 +310,14 @@ public class SignalAccount implements Closeable {
private void clearAllPreKeys() {
resetPreKeyOffsets(ServiceIdType.ACI);
resetPreKeyOffsets(ServiceIdType.PNI);
resetKyberPreKeyOffsets(ServiceIdType.ACI);
resetKyberPreKeyOffsets(ServiceIdType.PNI);
this.getAciPreKeyStore().removeAllPreKeys();
this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
this.getAciKyberPreKeyStore().removeAllKyberPreKeys();
this.getPniPreKeyStore().removeAllPreKeys();
this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
this.getPniKyberPreKeyStore().removeAllKyberPreKeys();
save();
}
@ -614,22 +626,42 @@ public class SignalAccount implements Closeable {
if (rootNode.hasNonNull("preKeyIdOffset")) {
aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
} else {
aciPreKeyIdOffset = 1;
aciPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("nextSignedPreKeyId")) {
aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
} else {
aciNextSignedPreKeyId = 1;
aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
} else {
pniPreKeyIdOffset = 1;
pniPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
} else {
pniNextSignedPreKeyId = 1;
pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
} else {
aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) {
aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1);
} else {
aciActiveLastResortKyberPreKeyId = -1;
}
if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) {
pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
} else {
pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) {
pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1);
} else {
pniActiveLastResortKyberPreKeyId = -1;
}
if (rootNode.hasNonNull("profileKey")) {
try {
@ -974,6 +1006,10 @@ public class SignalAccount implements Closeable {
.put("nextSignedPreKeyId", aciNextSignedPreKeyId)
.put("pniPreKeyIdOffset", pniPreKeyIdOffset)
.put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
.put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset)
.put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId)
.put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset)
.put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId)
.put("profileKey",
profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
.put("registered", registered)
@ -1019,15 +1055,19 @@ public class SignalAccount implements Closeable {
public void resetPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.aciPreKeyIdOffset = getRandomPreKeyIdOffset();
this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
} else {
this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.pniPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
}
save();
}
private static int getRandomPreKeyIdOffset() {
return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID);
}
public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciPreKeys(records);
@ -1036,26 +1076,26 @@ public class SignalAccount implements Closeable {
}
}
public void addAciPreKeys(List<PreKeyRecord> records) {
private void addAciPreKeys(List<PreKeyRecord> records) {
for (var record : records) {
if (aciPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getAciPreKeyStore().storePreKey(record.getId(), record);
aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE;
aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
public void addPniPreKeys(List<PreKeyRecord> records) {
private void addPniPreKeys(List<PreKeyRecord> records) {
for (var record : records) {
if (pniPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getPniPreKeyStore().storePreKey(record.getId(), record);
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE;
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
@ -1074,7 +1114,7 @@ public class SignalAccount implements Closeable {
throw new AssertionError("Invalid signed pre key id");
}
getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save();
}
@ -1084,7 +1124,84 @@ public class SignalAccount implements Closeable {
throw new AssertionError("Invalid signed pre key id");
}
getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.aciActiveLastResortKyberPreKeyId = -1;
} else {
this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniActiveLastResortKyberPreKeyId = -1;
}
save();
}
public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciKyberPreKeys(records);
} else {
addPniKyberPreKeys(records);
}
}
private void addAciKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) {
if (aciKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id");
}
getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
private void addPniKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id");
}
getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciLastResortKyberPreKey(record);
} else {
addPniLastResortKyberPreKey(record);
}
}
public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) {
if (aciKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(),
aciKyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id");
}
getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
aciActiveLastResortKyberPreKeyId = record.getId();
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(),
pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id");
}
getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
pniActiveLastResortKyberPreKeyId = record.getId();
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save();
}
@ -1126,6 +1243,7 @@ public class SignalAccount implements Closeable {
return getOrCreate(() -> aciSignalProtocolStore,
() -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getAciSignedPreKeyStore(),
getAciKyberPreKeyStore(),
getAciSessionStore(),
getAciIdentityKeyStore(),
getSenderKeyStore(),
@ -1136,6 +1254,7 @@ public class SignalAccount implements Closeable {
return getOrCreate(() -> pniSignalProtocolStore,
() -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(),
getPniSignedPreKeyStore(),
getPniKyberPreKeyStore(),
getPniSessionStore(),
getPniIdentityKeyStore(),
getSenderKeyStore(),
@ -1152,6 +1271,11 @@ public class SignalAccount implements Closeable {
() -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private KyberPreKeyStore getAciKyberPreKeyStore() {
return getOrCreate(() -> aciKyberPreKeyStore,
() -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
@ -1162,6 +1286,11 @@ public class SignalAccount implements Closeable {
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
private KyberPreKeyStore getPniKyberPreKeyStore() {
return getOrCreate(() -> pniKyberPreKeyStore,
() -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
public SessionStore getAciSessionStore() {
return getOrCreate(() -> aciSessionStore,
() -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
@ -1362,6 +1491,9 @@ public class SignalAccount implements Closeable {
identityKeyStore.deleteIdentity(this.pni);
getPniPreKeyStore().removeAllPreKeys();
getPniSignedPreKeyStore().removeAllSignedPreKeys();
getPniKyberPreKeyStore().removeAllKyberPreKeys();
aciActiveLastResortKyberPreKeyId = -1;
pniActiveLastResortKyberPreKeyId = -1;
}
this.pni = updatedPni;
@ -1406,10 +1538,6 @@ public class SignalAccount implements Closeable {
save();
}
public byte[] getEncryptedDeviceName() {
return encryptedDeviceName == null ? null : Base64.getDecoder().decode(encryptedDeviceName);
}
public void setEncryptedDeviceName(final String encryptedDeviceName) {
this.encryptedDeviceName = encryptedDeviceName;
save();
@ -1590,6 +1718,10 @@ public class SignalAccount implements Closeable {
return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
}
public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset;
}
public boolean isRegistered() {
return registered;
}

View file

@ -0,0 +1,211 @@
package org.asamk.signal.manager.storage.prekeys;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.signal.libsignal.protocol.InvalidKeyIdException;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key";
private final static Logger logger = LoggerFactory.getLogger(KyberPreKeyStore.class);
private final Database database;
private final int accountIdType;
public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE kyber_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
serialized BLOB NOT NULL,
is_last_resort INTEGER NOT NULL,
UNIQUE(account_id_type, key_id)
) STRICT;
""");
}
}
public KyberPreKeyStore(final Database database, final ServiceIdType serviceIdType) {
this.database = database;
this.accountIdType = Utils.getAccountIdType(serviceIdType);
}
@Override
public KyberPreKeyRecord loadKyberPreKey(final int keyId) throws InvalidKeyIdException {
final var kyberPreKey = getPreKey(keyId);
if (kyberPreKey == null) {
throw new InvalidKeyIdException("No such kyber pre key record: " + keyId);
}
return kyberPreKey;
}
@Override
public List<KyberPreKeyRecord> loadKyberPreKeys() {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList();
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
@Override
public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ? AND p.is_last_resort = TRUE
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList();
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
@Override
public void storeLastResortKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
storeKyberPreKey(keyId, record, true);
}
@Override
public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
storeKyberPreKey(keyId, record, false);
}
public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) {
final var sql = (
"""
INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort)
VALUES (?, ?, ?, ?)
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.setBytes(3, record.serialize());
statement.setBoolean(4, isLastResort);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
@Override
public boolean containsKyberPreKey(final int keyId) {
return getPreKey(keyId) != null;
}
@Override
public void markKyberPreKeyUsed(final int keyId) {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ? AND p.key_id = ? AND p.is_last_resort = FALSE
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
@Override
public void removeKyberPreKey(final int keyId) {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
public void removeAllKyberPreKeys() {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
private KyberPreKeyRecord getPreKey(int keyId) {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
return Utils.executeQueryForOptional(statement, this::getKyberPreKeyRecordFromResultSet).orElse(null);
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
private KyberPreKeyRecord getKyberPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
try {
final var serialized = resultSet.getBytes("serialized");
return new KyberPreKeyRecord(serialized);
} catch (InvalidMessageException e) {
return null;
}
}
}

View file

@ -49,7 +49,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt
public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
final var preKey = getPreKey(preKeyId);
if (preKey == null) {
throw new InvalidKeyIdException("No such signed pre key record: " + preKeyId);
throw new InvalidKeyIdException("No such pre key record: " + preKeyId);
}
return preKey;
}

View file

@ -14,6 +14,7 @@ import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyStore;
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.DistributionId;
@ -28,6 +29,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
private final PreKeyStore preKeyStore;
private final SignedPreKeyStore signedPreKeyStore;
private final SignalServiceKyberPreKeyStore kyberPreKeyStore;
private final SignalServiceSessionStore sessionStore;
private final IdentityKeyStore identityKeyStore;
private final SignalServiceSenderKeyStore senderKeyStore;
@ -36,6 +38,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
public SignalProtocolStore(
final PreKeyStore preKeyStore,
final SignedPreKeyStore signedPreKeyStore,
final SignalServiceKyberPreKeyStore kyberPreKeyStore,
final SignalServiceSessionStore sessionStore,
final IdentityKeyStore identityKeyStore,
final SignalServiceSenderKeyStore senderKeyStore,
@ -43,6 +46,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
) {
this.preKeyStore = preKeyStore;
this.signedPreKeyStore = signedPreKeyStore;
this.kyberPreKeyStore = kyberPreKeyStore;
this.sessionStore = sessionStore;
this.identityKeyStore = identityKeyStore;
this.senderKeyStore = senderKeyStore;
@ -201,45 +205,41 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
@Override
public KyberPreKeyRecord loadKyberPreKey(final int kyberPreKeyId) throws InvalidKeyIdException {
// TODO
throw new InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId");
return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId);
}
@Override
public List<KyberPreKeyRecord> loadKyberPreKeys() {
// TODO
return List.of();
return kyberPreKeyStore.loadKyberPreKeys();
}
@Override
public void storeKyberPreKey(final int kyberPreKeyId, final KyberPreKeyRecord record) {
// TODO
kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record);
}
@Override
public boolean containsKyberPreKey(final int kyberPreKeyId) {
// TODO
return false;
return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId);
}
@Override
public void markKyberPreKeyUsed(final int kyberPreKeyId) {
// TODO
kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId);
}
@Override
public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
// TODO
return List.of();
return kyberPreKeyStore.loadLastResortKyberPreKeys();
}
@Override
public void removeKyberPreKey(final int i) {
// TODO
kyberPreKeyStore.removeKyberPreKey(i);
}
@Override
public void storeLastResortKyberPreKey(final int i, final KyberPreKeyRecord kyberPreKeyRecord) {
// TODO
kyberPreKeyStore.storeLastResortKyberPreKey(i, kyberPreKeyRecord);
}
}

View file

@ -406,9 +406,7 @@ public class SessionStore implements SignalServiceSessionStore {
}
private static boolean isActive(SessionRecord record) {
return record != null
&& record.hasSenderChain()
&& record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
return record != null && record.hasSenderChain();
}
record Key(ServiceId serviceId, int deviceId) {}

View file

@ -5,9 +5,11 @@ import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPrivateKey;
import org.signal.libsignal.protocol.kem.KEMKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyType;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.util.Medium;
import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.kbs.MasterKey;
@ -17,6 +19,8 @@ import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
public class KeyUtils {
private static final SecureRandom secureRandom = new SecureRandom();
@ -46,7 +50,7 @@ public class KeyUtils {
public static List<PreKeyRecord> generatePreKeyRecords(final int offset, final int batchSize) {
var records = new ArrayList<PreKeyRecord>(batchSize);
for (var i = 0; i < batchSize; i++) {
var preKeyId = (offset + i) % Medium.MAX_VALUE;
var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID;
var keyPair = Curve.generateKeyPair();
var record = new PreKeyRecord(preKeyId, keyPair);
@ -68,6 +72,24 @@ public class KeyUtils {
return new SignedPreKeyRecord(signedPreKeyId, System.currentTimeMillis(), keyPair, signature);
}
public static List<KyberPreKeyRecord> generateKyberPreKeyRecords(
final int offset, final int batchSize, final ECPrivateKey privateKey
) {
var records = new ArrayList<KyberPreKeyRecord>(batchSize);
for (var i = 0; i < batchSize; i++) {
var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID;
records.add(generateKyberPreKeyRecord(preKeyId, privateKey));
}
return records;
}
public static KyberPreKeyRecord generateKyberPreKeyRecord(final int preKeyId, final ECPrivateKey privateKey) {
KEMKeyPair keyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024);
byte[] signature = privateKey.calculateSignature(keyPair.getPublicKey().serialize());
return new KyberPreKeyRecord(preKeyId, System.currentTimeMillis(), keyPair, signature);
}
public static ProfileKey createProfileKey() {
try {
return new ProfileKey(getSecretBytes(32));