Implement support for kyber pre keys

This commit is contained in:
AsamK 2023-06-17 21:18:24 +02:00
parent 4e5c859aab
commit 306e38c9ee
11 changed files with 504 additions and 45 deletions

View file

@ -46,7 +46,7 @@
}, },
{ {
"name":"org.asamk.signal.manager.storage.protocol.SignalProtocolStore", "name":"org.asamk.signal.manager.storage.protocol.SignalProtocolStore",
"methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }] "methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadKyberPreKey","parameterTypes":["int"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"markKyberPreKeyUsed","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }]
}, },
{ {
"name":"org.asamk.signal.manager.storage.senderKeys.SenderKeyStore", "name":"org.asamk.signal.manager.storage.senderKeys.SenderKeyStore",
@ -133,6 +133,10 @@
"name":"org.signal.libsignal.protocol.state.IdentityKeyStore$Direction", "name":"org.signal.libsignal.protocol.state.IdentityKeyStore$Direction",
"fields":[{"name":"RECEIVING"}, {"name":"SENDING"}] "fields":[{"name":"RECEIVING"}, {"name":"SENDING"}]
}, },
{
"name":"org.signal.libsignal.protocol.state.KyberPreKeyRecord",
"fields":[{"name":"unsafeHandle"}]
},
{ {
"name":"org.signal.libsignal.protocol.state.KyberPreKeyStore" "name":"org.signal.libsignal.protocol.state.KyberPreKeyStore"
}, },

View file

@ -2179,16 +2179,25 @@
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity", "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity",
"allDeclaredFields":true, "allDeclaredFields":true,
"queryAllDeclaredMethods":true, "queryAllDeclaredMethods":true,
"queryAllDeclaredConstructors":true "queryAllDeclaredConstructors":true,
"methods":[{"name":"<init>","parameterTypes":[] }, {"name":"getKeyId","parameterTypes":[] }, {"name":"getPublicKey","parameterTypes":[] }, {"name":"getSignature","parameterTypes":[] }]
}, },
{ {
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArrayDeserializer", "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArrayDeserializer",
"methods":[{"name":"<init>","parameterTypes":[] }] "methods":[{"name":"<init>","parameterTypes":[] }]
}, },
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArraySerializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{ {
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeyDeserializer", "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeyDeserializer",
"methods":[{"name":"<init>","parameterTypes":[] }] "methods":[{"name":"<init>","parameterTypes":[] }]
}, },
{
"name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeySerializer",
"methods":[{"name":"<init>","parameterTypes":[] }]
},
{ {
"name":"org.whispersystems.signalservice.internal.push.MismatchedDevices", "name":"org.whispersystems.signalservice.internal.push.MismatchedDevices",
"allDeclaredFields":true, "allDeclaredFields":true,
@ -2250,7 +2259,8 @@
"name":"org.whispersystems.signalservice.internal.push.PreKeyState", "name":"org.whispersystems.signalservice.internal.push.PreKeyState",
"allDeclaredFields":true, "allDeclaredFields":true,
"allDeclaredMethods":true, "allDeclaredMethods":true,
"allDeclaredConstructors":true "allDeclaredConstructors":true,
"methods":[{"name":"getIdentityKey","parameterTypes":[] }, {"name":"getPreKeys","parameterTypes":[] }, {"name":"getSignedPreKey","parameterTypes":[] }]
}, },
{ {
"name":"org.whispersystems.signalservice.internal.push.PreKeyStatus", "name":"org.whispersystems.signalservice.internal.push.PreKeyStatus",

View file

@ -1,6 +1,7 @@
package org.asamk.signal.manager.config; package org.asamk.signal.manager.config;
import org.asamk.signal.manager.api.ServiceEnvironment; import org.asamk.signal.manager.api.ServiceEnvironment;
import org.signal.libsignal.protocol.util.Medium;
import org.whispersystems.signalservice.api.account.AccountAttributes; import org.whispersystems.signalservice.api.account.AccountAttributes;
import org.whispersystems.signalservice.api.push.TrustStore; import org.whispersystems.signalservice.api.push.TrustStore;
@ -15,8 +16,9 @@ import okhttp3.Interceptor;
public class ServiceConfig { public class ServiceConfig {
public final static int PREKEY_MINIMUM_COUNT = 20; public final static int PREKEY_MINIMUM_COUNT = 10;
public final static int PREKEY_BATCH_SIZE = 100; public final static int PREKEY_BATCH_SIZE = 100;
public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE;
public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024; public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024;
public final static long MAX_ENVELOPE_SIZE = 0; public final static long MAX_ENVELOPE_SIZE = 0;
public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024; public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024;

View file

@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies;
import org.asamk.signal.manager.storage.SignalAccount; import org.asamk.signal.manager.storage.SignalAccount;
import org.asamk.signal.manager.util.KeyUtils; import org.asamk.signal.manager.util.KeyUtils;
import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -39,7 +40,9 @@ public class PreKeyHelper {
if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
refreshPreKeys(serviceIdType); refreshPreKeys(serviceIdType);
} }
// TODO kyber pre keys if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
refreshKyberPreKeys(serviceIdType);
}
} }
public void refreshPreKeys() throws IOException { public void refreshPreKeys() throws IOException {
@ -47,7 +50,7 @@ public class PreKeyHelper {
refreshPreKeys(ServiceIdType.PNI); refreshPreKeys(ServiceIdType.PNI);
} }
public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) { if (identityKeyPair == null) {
return; return;
@ -97,4 +100,61 @@ public class PreKeyHelper {
return record; return record;
} }
private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return;
}
final var accountId = account.getAccountId(serviceIdType);
if (accountId == null) {
return;
}
try {
refreshKyberPreKeys(serviceIdType, identityKeyPair);
} catch (Exception e) {
logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
account.resetKyberPreKeyOffsets(serviceIdType);
refreshKyberPreKeys(serviceIdType, identityKeyPair);
}
}
private void refreshKyberPreKeys(
final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) throws IOException {
final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair);
final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
final var preKeyUpload = new PreKeyUpload(serviceIdType,
identityKeyPair.getPublicKey(),
null,
null,
lastResortPreKeyRecord,
oneTimePreKeys);
dependencies.getAccountManager().setPreKeys(preKeyUpload);
}
private List<KyberPreKeyRecord> generateKyberPreKeys(
ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) {
final var offset = account.getKyberPreKeyIdOffset(serviceIdType);
var records = KeyUtils.generateKyberPreKeyRecords(offset,
ServiceConfig.PREKEY_BATCH_SIZE,
identityKeyPair.getPrivateKey());
account.addKyberPreKeys(serviceIdType, records);
return records;
}
private KyberPreKeyRecord generateLastResortKyberPreKey(
ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
) {
final var signedPreKeyId = account.getKyberPreKeyIdOffset(serviceIdType);
var record = KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
account.addLastResortKyberPreKey(serviceIdType, record);
return record;
}
} }

View file

@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariDataSource;
import org.asamk.signal.manager.storage.groups.GroupStore; import org.asamk.signal.manager.storage.groups.GroupStore;
import org.asamk.signal.manager.storage.identities.IdentityKeyStore; import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore; import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientStore;
@ -23,7 +24,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database { public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 13; private static final long DATABASE_VERSION = 14;
private AccountDatabase(final HikariDataSource dataSource) { private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource); super(logger, DATABASE_VERSION, dataSource);
@ -40,6 +41,7 @@ public class AccountDatabase extends Database {
StickerStore.createSql(connection); StickerStore.createSql(connection);
PreKeyStore.createSql(connection); PreKeyStore.createSql(connection);
SignedPreKeyStore.createSql(connection); SignedPreKeyStore.createSql(connection);
KyberPreKeyStore.createSql(connection);
GroupStore.createSql(connection); GroupStore.createSql(connection);
SessionStore.createSql(connection); SessionStore.createSql(connection);
IdentityKeyStore.createSql(connection); IdentityKeyStore.createSql(connection);
@ -328,5 +330,23 @@ public class AccountDatabase extends Database {
} }
} }
} }
if (oldVersion < 14) {
logger.debug("Updating database: Creating kyber_pre_key table");
{
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE kyber_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
serialized BLOB NOT NULL,
is_last_resort INTEGER NOT NULL,
UNIQUE(account_id_type, key_id)
) STRICT;
""");
}
}
}
} }
} }

View file

@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
import org.asamk.signal.manager.storage.identities.LegacyIdentityKeyStore; import org.asamk.signal.manager.storage.identities.LegacyIdentityKeyStore;
import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore; import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore;
import org.asamk.signal.manager.storage.messageCache.MessageCache; import org.asamk.signal.manager.storage.messageCache.MessageCache;
import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore; import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore; import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
@ -51,11 +52,11 @@ import org.asamk.signal.manager.util.KeyUtils;
import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.util.KeyHelper; import org.signal.libsignal.protocol.util.KeyHelper;
import org.signal.libsignal.protocol.util.Medium;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -98,6 +99,7 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
import static org.asamk.signal.manager.config.ServiceConfig.getCapabilities; import static org.asamk.signal.manager.config.ServiceConfig.getCapabilities;
public class SignalAccount implements Closeable { public class SignalAccount implements Closeable {
@ -105,7 +107,7 @@ public class SignalAccount implements Closeable {
private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class); private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class);
private static final int MINIMUM_STORAGE_VERSION = 1; private static final int MINIMUM_STORAGE_VERSION = 1;
private static final int CURRENT_STORAGE_VERSION = 6; private static final int CURRENT_STORAGE_VERSION = 7;
private final Object LOCK = new Object(); private final Object LOCK = new Object();
@ -138,6 +140,10 @@ public class SignalAccount implements Closeable {
private int aciNextSignedPreKeyId = 1; private int aciNextSignedPreKeyId = 1;
private int pniPreKeyIdOffset = 1; private int pniPreKeyIdOffset = 1;
private int pniNextSignedPreKeyId = 1; private int pniNextSignedPreKeyId = 1;
private int aciKyberPreKeyIdOffset = 1;
private int aciActiveLastResortKyberPreKeyId = -1;
private int pniKyberPreKeyIdOffset = 1;
private int pniActiveLastResortKyberPreKeyId = -1;
private IdentityKeyPair aciIdentityKeyPair; private IdentityKeyPair aciIdentityKeyPair;
private IdentityKeyPair pniIdentityKeyPair; private IdentityKeyPair pniIdentityKeyPair;
private int localRegistrationId; private int localRegistrationId;
@ -151,8 +157,10 @@ public class SignalAccount implements Closeable {
private SignalProtocolStore pniSignalProtocolStore; private SignalProtocolStore pniSignalProtocolStore;
private PreKeyStore aciPreKeyStore; private PreKeyStore aciPreKeyStore;
private SignedPreKeyStore aciSignedPreKeyStore; private SignedPreKeyStore aciSignedPreKeyStore;
private KyberPreKeyStore aciKyberPreKeyStore;
private PreKeyStore pniPreKeyStore; private PreKeyStore pniPreKeyStore;
private SignedPreKeyStore pniSignedPreKeyStore; private SignedPreKeyStore pniSignedPreKeyStore;
private KyberPreKeyStore pniKyberPreKeyStore;
private SessionStore aciSessionStore; private SessionStore aciSessionStore;
private SessionStore pniSessionStore; private SessionStore pniSessionStore;
private IdentityKeyStore identityKeyStore; private IdentityKeyStore identityKeyStore;
@ -302,10 +310,14 @@ public class SignalAccount implements Closeable {
private void clearAllPreKeys() { private void clearAllPreKeys() {
resetPreKeyOffsets(ServiceIdType.ACI); resetPreKeyOffsets(ServiceIdType.ACI);
resetPreKeyOffsets(ServiceIdType.PNI); resetPreKeyOffsets(ServiceIdType.PNI);
resetKyberPreKeyOffsets(ServiceIdType.ACI);
resetKyberPreKeyOffsets(ServiceIdType.PNI);
this.getAciPreKeyStore().removeAllPreKeys(); this.getAciPreKeyStore().removeAllPreKeys();
this.getAciSignedPreKeyStore().removeAllSignedPreKeys(); this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
this.getAciKyberPreKeyStore().removeAllKyberPreKeys();
this.getPniPreKeyStore().removeAllPreKeys(); this.getPniPreKeyStore().removeAllPreKeys();
this.getPniSignedPreKeyStore().removeAllSignedPreKeys(); this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
this.getPniKyberPreKeyStore().removeAllKyberPreKeys();
save(); save();
} }
@ -614,22 +626,42 @@ public class SignalAccount implements Closeable {
if (rootNode.hasNonNull("preKeyIdOffset")) { if (rootNode.hasNonNull("preKeyIdOffset")) {
aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
} else { } else {
aciPreKeyIdOffset = 1; aciPreKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("nextSignedPreKeyId")) { if (rootNode.hasNonNull("nextSignedPreKeyId")) {
aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
} else { } else {
aciNextSignedPreKeyId = 1; aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("pniPreKeyIdOffset")) { if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
} else { } else {
pniPreKeyIdOffset = 1; pniPreKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("pniNextSignedPreKeyId")) { if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
} else { } else {
pniNextSignedPreKeyId = 1; pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
} else {
aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) {
aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1);
} else {
aciActiveLastResortKyberPreKeyId = -1;
}
if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) {
pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
} else {
pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
}
if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) {
pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1);
} else {
pniActiveLastResortKyberPreKeyId = -1;
} }
if (rootNode.hasNonNull("profileKey")) { if (rootNode.hasNonNull("profileKey")) {
try { try {
@ -974,6 +1006,10 @@ public class SignalAccount implements Closeable {
.put("nextSignedPreKeyId", aciNextSignedPreKeyId) .put("nextSignedPreKeyId", aciNextSignedPreKeyId)
.put("pniPreKeyIdOffset", pniPreKeyIdOffset) .put("pniPreKeyIdOffset", pniPreKeyIdOffset)
.put("pniNextSignedPreKeyId", pniNextSignedPreKeyId) .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
.put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset)
.put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId)
.put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset)
.put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId)
.put("profileKey", .put("profileKey",
profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
.put("registered", registered) .put("registered", registered)
@ -1019,15 +1055,19 @@ public class SignalAccount implements Closeable {
public void resetPreKeyOffsets(final ServiceIdType serviceIdType) { public void resetPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) { if (serviceIdType.equals(ServiceIdType.ACI)) {
this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); this.aciPreKeyIdOffset = getRandomPreKeyIdOffset();
this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
} else { } else {
this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); this.pniPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
} }
save(); save();
} }
private static int getRandomPreKeyIdOffset() {
return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID);
}
public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) { public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) { if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciPreKeys(records); addAciPreKeys(records);
@ -1036,26 +1076,26 @@ public class SignalAccount implements Closeable {
} }
} }
public void addAciPreKeys(List<PreKeyRecord> records) { private void addAciPreKeys(List<PreKeyRecord> records) {
for (var record : records) { for (var record : records) {
if (aciPreKeyIdOffset != record.getId()) { if (aciPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset); logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
throw new AssertionError("Invalid pre key id"); throw new AssertionError("Invalid pre key id");
} }
getAciPreKeyStore().storePreKey(record.getId(), record); getAciPreKeyStore().storePreKey(record.getId(), record);
aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE; aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
} }
save(); save();
} }
public void addPniPreKeys(List<PreKeyRecord> records) { private void addPniPreKeys(List<PreKeyRecord> records) {
for (var record : records) { for (var record : records) {
if (pniPreKeyIdOffset != record.getId()) { if (pniPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset); logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
throw new AssertionError("Invalid pre key id"); throw new AssertionError("Invalid pre key id");
} }
getPniPreKeyStore().storePreKey(record.getId(), record); getPniPreKeyStore().storePreKey(record.getId(), record);
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE; pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
} }
save(); save();
} }
@ -1074,7 +1114,7 @@ public class SignalAccount implements Closeable {
throw new AssertionError("Invalid signed pre key id"); throw new AssertionError("Invalid signed pre key id");
} }
getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record); getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE; aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save(); save();
} }
@ -1084,7 +1124,84 @@ public class SignalAccount implements Closeable {
throw new AssertionError("Invalid signed pre key id"); throw new AssertionError("Invalid signed pre key id");
} }
getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record); getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE; pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.aciActiveLastResortKyberPreKeyId = -1;
} else {
this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniActiveLastResortKyberPreKeyId = -1;
}
save();
}
public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciKyberPreKeys(records);
} else {
addPniKyberPreKeys(records);
}
}
private void addAciKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) {
if (aciKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id");
}
getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
private void addPniKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id");
}
getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciLastResortKyberPreKey(record);
} else {
addPniLastResortKyberPreKey(record);
}
}
public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) {
if (aciKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(),
aciKyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id");
}
getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
aciActiveLastResortKyberPreKeyId = record.getId();
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(),
pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id");
}
getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
pniActiveLastResortKyberPreKeyId = record.getId();
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save(); save();
} }
@ -1126,6 +1243,7 @@ public class SignalAccount implements Closeable {
return getOrCreate(() -> aciSignalProtocolStore, return getOrCreate(() -> aciSignalProtocolStore,
() -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(), () -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getAciSignedPreKeyStore(), getAciSignedPreKeyStore(),
getAciKyberPreKeyStore(),
getAciSessionStore(), getAciSessionStore(),
getAciIdentityKeyStore(), getAciIdentityKeyStore(),
getSenderKeyStore(), getSenderKeyStore(),
@ -1136,6 +1254,7 @@ public class SignalAccount implements Closeable {
return getOrCreate(() -> pniSignalProtocolStore, return getOrCreate(() -> pniSignalProtocolStore,
() -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(), () -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(),
getPniSignedPreKeyStore(), getPniSignedPreKeyStore(),
getPniKyberPreKeyStore(),
getPniSessionStore(), getPniSessionStore(),
getPniIdentityKeyStore(), getPniIdentityKeyStore(),
getSenderKeyStore(), getSenderKeyStore(),
@ -1152,6 +1271,11 @@ public class SignalAccount implements Closeable {
() -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
} }
private KyberPreKeyStore getAciKyberPreKeyStore() {
return getOrCreate(() -> aciKyberPreKeyStore,
() -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private PreKeyStore getPniPreKeyStore() { private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore, return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
@ -1162,6 +1286,11 @@ public class SignalAccount implements Closeable {
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
} }
private KyberPreKeyStore getPniKyberPreKeyStore() {
return getOrCreate(() -> pniKyberPreKeyStore,
() -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
public SessionStore getAciSessionStore() { public SessionStore getAciSessionStore() {
return getOrCreate(() -> aciSessionStore, return getOrCreate(() -> aciSessionStore,
() -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI)); () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
@ -1362,6 +1491,9 @@ public class SignalAccount implements Closeable {
identityKeyStore.deleteIdentity(this.pni); identityKeyStore.deleteIdentity(this.pni);
getPniPreKeyStore().removeAllPreKeys(); getPniPreKeyStore().removeAllPreKeys();
getPniSignedPreKeyStore().removeAllSignedPreKeys(); getPniSignedPreKeyStore().removeAllSignedPreKeys();
getPniKyberPreKeyStore().removeAllKyberPreKeys();
aciActiveLastResortKyberPreKeyId = -1;
pniActiveLastResortKyberPreKeyId = -1;
} }
this.pni = updatedPni; this.pni = updatedPni;
@ -1406,10 +1538,6 @@ public class SignalAccount implements Closeable {
save(); save();
} }
public byte[] getEncryptedDeviceName() {
return encryptedDeviceName == null ? null : Base64.getDecoder().decode(encryptedDeviceName);
}
public void setEncryptedDeviceName(final String encryptedDeviceName) { public void setEncryptedDeviceName(final String encryptedDeviceName) {
this.encryptedDeviceName = encryptedDeviceName; this.encryptedDeviceName = encryptedDeviceName;
save(); save();
@ -1590,6 +1718,10 @@ public class SignalAccount implements Closeable {
return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId; return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
} }
public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset;
}
public boolean isRegistered() { public boolean isRegistered() {
return registered; return registered;
} }

View file

@ -0,0 +1,211 @@
package org.asamk.signal.manager.storage.prekeys;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.signal.libsignal.protocol.InvalidKeyIdException;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key";
private final static Logger logger = LoggerFactory.getLogger(KyberPreKeyStore.class);
private final Database database;
private final int accountIdType;
public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE kyber_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
serialized BLOB NOT NULL,
is_last_resort INTEGER NOT NULL,
UNIQUE(account_id_type, key_id)
) STRICT;
""");
}
}
public KyberPreKeyStore(final Database database, final ServiceIdType serviceIdType) {
this.database = database;
this.accountIdType = Utils.getAccountIdType(serviceIdType);
}
@Override
public KyberPreKeyRecord loadKyberPreKey(final int keyId) throws InvalidKeyIdException {
final var kyberPreKey = getPreKey(keyId);
if (kyberPreKey == null) {
throw new InvalidKeyIdException("No such kyber pre key record: " + keyId);
}
return kyberPreKey;
}
@Override
public List<KyberPreKeyRecord> loadKyberPreKeys() {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList();
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
@Override
public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ? AND p.is_last_resort = TRUE
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList();
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
@Override
public void storeLastResortKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
storeKyberPreKey(keyId, record, true);
}
@Override
public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
storeKyberPreKey(keyId, record, false);
}
public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) {
final var sql = (
"""
INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort)
VALUES (?, ?, ?, ?)
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.setBytes(3, record.serialize());
statement.setBoolean(4, isLastResort);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
@Override
public boolean containsKyberPreKey(final int keyId) {
return getPreKey(keyId) != null;
}
@Override
public void markKyberPreKeyUsed(final int keyId) {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ? AND p.key_id = ? AND p.is_last_resort = FALSE
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
@Override
public void removeKyberPreKey(final int keyId) {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
public void removeAllKyberPreKeys() {
final var sql = (
"""
DELETE FROM %s AS p
WHERE p.account_id_type = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update kyber_pre_key store", e);
}
}
private KyberPreKeyRecord getPreKey(int keyId) {
final var sql = (
"""
SELECT p.serialized
FROM %s p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_KYBER_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setInt(2, keyId);
return Utils.executeQueryForOptional(statement, this::getKyberPreKeyRecordFromResultSet).orElse(null);
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from kyber_pre_key store", e);
}
}
private KyberPreKeyRecord getKyberPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
try {
final var serialized = resultSet.getBytes("serialized");
return new KyberPreKeyRecord(serialized);
} catch (InvalidMessageException e) {
return null;
}
}
}

View file

@ -49,7 +49,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt
public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException { public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
final var preKey = getPreKey(preKeyId); final var preKey = getPreKey(preKeyId);
if (preKey == null) { if (preKey == null) {
throw new InvalidKeyIdException("No such signed pre key record: " + preKeyId); throw new InvalidKeyIdException("No such pre key record: " + preKeyId);
} }
return preKey; return preKey;
} }

View file

@ -14,6 +14,7 @@ import org.signal.libsignal.protocol.state.SessionRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyStore; import org.signal.libsignal.protocol.state.SignedPreKeyStore;
import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore; import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
@ -28,6 +29,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
private final PreKeyStore preKeyStore; private final PreKeyStore preKeyStore;
private final SignedPreKeyStore signedPreKeyStore; private final SignedPreKeyStore signedPreKeyStore;
private final SignalServiceKyberPreKeyStore kyberPreKeyStore;
private final SignalServiceSessionStore sessionStore; private final SignalServiceSessionStore sessionStore;
private final IdentityKeyStore identityKeyStore; private final IdentityKeyStore identityKeyStore;
private final SignalServiceSenderKeyStore senderKeyStore; private final SignalServiceSenderKeyStore senderKeyStore;
@ -36,6 +38,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
public SignalProtocolStore( public SignalProtocolStore(
final PreKeyStore preKeyStore, final PreKeyStore preKeyStore,
final SignedPreKeyStore signedPreKeyStore, final SignedPreKeyStore signedPreKeyStore,
final SignalServiceKyberPreKeyStore kyberPreKeyStore,
final SignalServiceSessionStore sessionStore, final SignalServiceSessionStore sessionStore,
final IdentityKeyStore identityKeyStore, final IdentityKeyStore identityKeyStore,
final SignalServiceSenderKeyStore senderKeyStore, final SignalServiceSenderKeyStore senderKeyStore,
@ -43,6 +46,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
) { ) {
this.preKeyStore = preKeyStore; this.preKeyStore = preKeyStore;
this.signedPreKeyStore = signedPreKeyStore; this.signedPreKeyStore = signedPreKeyStore;
this.kyberPreKeyStore = kyberPreKeyStore;
this.sessionStore = sessionStore; this.sessionStore = sessionStore;
this.identityKeyStore = identityKeyStore; this.identityKeyStore = identityKeyStore;
this.senderKeyStore = senderKeyStore; this.senderKeyStore = senderKeyStore;
@ -201,45 +205,41 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
@Override @Override
public KyberPreKeyRecord loadKyberPreKey(final int kyberPreKeyId) throws InvalidKeyIdException { public KyberPreKeyRecord loadKyberPreKey(final int kyberPreKeyId) throws InvalidKeyIdException {
// TODO return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId);
throw new InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId");
} }
@Override @Override
public List<KyberPreKeyRecord> loadKyberPreKeys() { public List<KyberPreKeyRecord> loadKyberPreKeys() {
// TODO return kyberPreKeyStore.loadKyberPreKeys();
return List.of();
} }
@Override @Override
public void storeKyberPreKey(final int kyberPreKeyId, final KyberPreKeyRecord record) { public void storeKyberPreKey(final int kyberPreKeyId, final KyberPreKeyRecord record) {
// TODO kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record);
} }
@Override @Override
public boolean containsKyberPreKey(final int kyberPreKeyId) { public boolean containsKyberPreKey(final int kyberPreKeyId) {
// TODO return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId);
return false;
} }
@Override @Override
public void markKyberPreKeyUsed(final int kyberPreKeyId) { public void markKyberPreKeyUsed(final int kyberPreKeyId) {
// TODO kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId);
} }
@Override @Override
public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() { public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
// TODO return kyberPreKeyStore.loadLastResortKyberPreKeys();
return List.of();
} }
@Override @Override
public void removeKyberPreKey(final int i) { public void removeKyberPreKey(final int i) {
// TODO kyberPreKeyStore.removeKyberPreKey(i);
} }
@Override @Override
public void storeLastResortKyberPreKey(final int i, final KyberPreKeyRecord kyberPreKeyRecord) { public void storeLastResortKyberPreKey(final int i, final KyberPreKeyRecord kyberPreKeyRecord) {
// TODO kyberPreKeyStore.storeLastResortKyberPreKey(i, kyberPreKeyRecord);
} }
} }

View file

@ -406,9 +406,7 @@ public class SessionStore implements SignalServiceSessionStore {
} }
private static boolean isActive(SessionRecord record) { private static boolean isActive(SessionRecord record) {
return record != null return record != null && record.hasSenderChain();
&& record.hasSenderChain()
&& record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
} }
record Key(ServiceId serviceId, int deviceId) {} record Key(ServiceId serviceId, int deviceId) {}

View file

@ -5,9 +5,11 @@ import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPrivateKey; import org.signal.libsignal.protocol.ecc.ECPrivateKey;
import org.signal.libsignal.protocol.kem.KEMKeyPair;
import org.signal.libsignal.protocol.kem.KEMKeyType;
import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.util.Medium;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.whispersystems.signalservice.api.kbs.MasterKey; import org.whispersystems.signalservice.api.kbs.MasterKey;
@ -17,6 +19,8 @@ import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
public class KeyUtils { public class KeyUtils {
private static final SecureRandom secureRandom = new SecureRandom(); private static final SecureRandom secureRandom = new SecureRandom();
@ -46,7 +50,7 @@ public class KeyUtils {
public static List<PreKeyRecord> generatePreKeyRecords(final int offset, final int batchSize) { public static List<PreKeyRecord> generatePreKeyRecords(final int offset, final int batchSize) {
var records = new ArrayList<PreKeyRecord>(batchSize); var records = new ArrayList<PreKeyRecord>(batchSize);
for (var i = 0; i < batchSize; i++) { for (var i = 0; i < batchSize; i++) {
var preKeyId = (offset + i) % Medium.MAX_VALUE; var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID;
var keyPair = Curve.generateKeyPair(); var keyPair = Curve.generateKeyPair();
var record = new PreKeyRecord(preKeyId, keyPair); var record = new PreKeyRecord(preKeyId, keyPair);
@ -68,6 +72,24 @@ public class KeyUtils {
return new SignedPreKeyRecord(signedPreKeyId, System.currentTimeMillis(), keyPair, signature); return new SignedPreKeyRecord(signedPreKeyId, System.currentTimeMillis(), keyPair, signature);
} }
public static List<KyberPreKeyRecord> generateKyberPreKeyRecords(
final int offset, final int batchSize, final ECPrivateKey privateKey
) {
var records = new ArrayList<KyberPreKeyRecord>(batchSize);
for (var i = 0; i < batchSize; i++) {
var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID;
records.add(generateKyberPreKeyRecord(preKeyId, privateKey));
}
return records;
}
public static KyberPreKeyRecord generateKyberPreKeyRecord(final int preKeyId, final ECPrivateKey privateKey) {
KEMKeyPair keyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024);
byte[] signature = privateKey.calculateSignature(keyPair.getPublicKey().serialize());
return new KyberPreKeyRecord(preKeyId, System.currentTimeMillis(), keyPair, signature);
}
public static ProfileKey createProfileKey() { public static ProfileKey createProfileKey() {
try { try {
return new ProfileKey(getSecretBytes(32)); return new ProfileKey(getSecretBytes(32));