Refactor ACI/PNI store handling

This commit is contained in:
AsamK 2023-06-18 14:44:57 +02:00
parent 306e38c9ee
commit 0ebfd989d1
6 changed files with 302 additions and 339 deletions

View file

@ -8,15 +8,17 @@ public class RenewSessionAction implements HandleAction {
private final RecipientId recipientId; private final RecipientId recipientId;
private final ServiceId serviceId; private final ServiceId serviceId;
private final ServiceId accountId;
public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId) { public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId, final ServiceId accountId) {
this.recipientId = recipientId; this.recipientId = recipientId;
this.serviceId = serviceId; this.serviceId = serviceId;
this.accountId = accountId;
} }
@Override @Override
public void execute(Context context) throws Throwable { public void execute(Context context) throws Throwable {
context.getAccount().getAciSessionStore().archiveSessions(serviceId); context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId);
if (!recipientId.equals(context.getAccount().getSelfRecipientId())) { if (!recipientId.equals(context.getAccount().getSelfRecipientId())) {
context.getSendHelper().sendNullMessage(recipientId); context.getSendHelper().sendNullMessage(recipientId);
} }

View file

@ -18,22 +18,25 @@ public class SendRetryMessageRequestAction implements HandleAction {
private final ServiceId serviceId; private final ServiceId serviceId;
private final ProtocolException protocolException; private final ProtocolException protocolException;
private final SignalServiceEnvelope envelope; private final SignalServiceEnvelope envelope;
private final ServiceId accountId;
public SendRetryMessageRequestAction( public SendRetryMessageRequestAction(
final RecipientId recipientId, final RecipientId recipientId,
final ServiceId serviceId, final ServiceId serviceId,
final ProtocolException protocolException, final ProtocolException protocolException,
final SignalServiceEnvelope envelope final SignalServiceEnvelope envelope,
final ServiceId accountId
) { ) {
this.recipientId = recipientId; this.recipientId = recipientId;
this.serviceId = serviceId; this.serviceId = serviceId;
this.protocolException = protocolException; this.protocolException = protocolException;
this.envelope = envelope; this.envelope = envelope;
this.accountId = accountId;
} }
@Override @Override
public void execute(Context context) throws Throwable { public void execute(Context context) throws Throwable {
context.getAccount().getAciSessionStore().archiveSessions(serviceId); context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId);
int senderDevice = protocolException.getSenderDevice(); int senderDevice = protocolException.getSenderDevice();
Optional<GroupId> groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion( Optional<GroupId> groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion(

View file

@ -181,12 +181,13 @@ public final class IncomingMessageHandler {
.contains(Profile.Capability.senderKey); .contains(Profile.Capability.senderKey);
final var isSelfSenderKeyCapable = selfProfile != null && selfProfile.getCapabilities() final var isSelfSenderKeyCapable = selfProfile != null && selfProfile.getCapabilities()
.contains(Profile.Capability.senderKey); .contains(Profile.Capability.senderKey);
final var destination = getDestination(envelope).serviceId();
if (!isSelf && isSenderSenderKeyCapable && isSelfSenderKeyCapable) { if (!isSelf && isSenderSenderKeyCapable && isSelfSenderKeyCapable) {
logger.debug("Received invalid message, requesting message resend."); logger.debug("Received invalid message, requesting message resend.");
actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope)); actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope, destination));
} else { } else {
logger.debug("Received invalid message, queuing renew session action."); logger.debug("Received invalid message, queuing renew session action.");
actions.add(new RenewSessionAction(sender, serviceId)); actions.add(new RenewSessionAction(sender, serviceId, destination));
} }
} else { } else {
logger.debug("Received invalid message from invalid sender: {}", e.getSender()); logger.debug("Received invalid message from invalid sender: {}", e.getSender());
@ -346,7 +347,12 @@ public final class IncomingMessageHandler {
senderDeviceId, senderDeviceId,
message.getTimestamp()); message.getTimestamp());
if (message.getDeviceId() == account.getDeviceId()) { if (message.getDeviceId() == account.getDeviceId()) {
handleDecryptionErrorMessage(actions, sender, senderServiceId, senderDeviceId, message); handleDecryptionErrorMessage(actions,
sender,
senderServiceId,
senderDeviceId,
message,
destination.serviceId());
} else { } else {
logger.debug("Request is for another one of our devices"); logger.debug("Request is for another one of our devices");
} }
@ -430,7 +436,8 @@ public final class IncomingMessageHandler {
final RecipientId sender, final RecipientId sender,
final ServiceId senderServiceId, final ServiceId senderServiceId,
final int senderDeviceId, final int senderDeviceId,
final DecryptionErrorMessage message final DecryptionErrorMessage message,
final ServiceId destination
) { ) {
final var logEntries = account.getMessageSendLogStore() final var logEntries = account.getMessageSendLogStore()
.findMessages(senderServiceId, .findMessages(senderServiceId,
@ -443,14 +450,14 @@ public final class IncomingMessageHandler {
} }
if (message.getRatchetKey().isPresent()) { if (message.getRatchetKey().isPresent()) {
if (account.getAciSessionStore() final var sessionStore = account.getAccountData(destination).getSessionStore();
.isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) { if (sessionStore.isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) {
if (logEntries.isEmpty()) { if (logEntries.isEmpty()) {
logger.debug("Renewing the session with sender"); logger.debug("Renewing the session with sender");
actions.add(new RenewSessionAction(sender, senderServiceId)); actions.add(new RenewSessionAction(sender, senderServiceId, destination));
} else { } else {
logger.trace("Archiving the session with sender, a resend message has already been queued"); logger.trace("Archiving the session with sender, a resend message has already been queued");
context.getAccount().getAciSessionStore().archiveSessions(senderServiceId); sessionStore.archiveSessions(senderServiceId);
} }
} }
return; return;
@ -806,9 +813,12 @@ public final class IncomingMessageHandler {
} }
} }
final var selfAddress = isSync ? source : destination;
final var conversationPartnerAddress = isSync ? destination : source; final var conversationPartnerAddress = isSync ? destination : source;
if (conversationPartnerAddress != null && message.isEndSession()) { if (conversationPartnerAddress != null && message.isEndSession()) {
account.getAciSessionStore().deleteAllSessions(conversationPartnerAddress.serviceId()); account.getAccountData(selfAddress.serviceId())
.getSessionStore()
.deleteAllSessions(conversationPartnerAddress.serviceId());
} }
if (message.isExpirationUpdate() || message.getBody().isPresent()) { if (message.isExpirationUpdate() || message.getBody().isPresent()) {
if (message.getGroupContext().isPresent()) { if (message.getGroupContext().isPresent()) {
@ -854,10 +864,12 @@ public final class IncomingMessageHandler {
if (message.getQuote().isPresent()) { if (message.getQuote().isPresent()) {
final var quote = message.getQuote().get(); final var quote = message.getQuote().get();
for (var quotedAttachment : quote.getAttachments()) { if (quote.getAttachments() != null) {
final var thumbnail = quotedAttachment.getThumbnail(); for (var quotedAttachment : quote.getAttachments()) {
if (thumbnail != null) { final var thumbnail = quotedAttachment.getThumbnail();
context.getAttachmentHelper().downloadAttachment(thumbnail); if (thumbnail != null) {
context.getAttachmentHelper().downloadAttachment(thumbnail);
}
} }
} }
} }
@ -972,7 +984,9 @@ public final class IncomingMessageHandler {
return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId()); return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId());
} }
final var address = addressOptional.get(); final var address = addressOptional.get();
return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address), address.getServiceId(), 0); return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address),
address.getServiceId(),
account.getDeviceId());
} }
private record DeviceAddress(RecipientId recipientId, ServiceId serviceId, int deviceId) {} private record DeviceAddress(RecipientId recipientId, ServiceId serviceId, int deviceId) {}

View file

@ -82,6 +82,7 @@ import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage
import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage; import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage;
import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.util.DeviceNameUtil; import org.whispersystems.signalservice.api.util.DeviceNameUtil;
import org.whispersystems.signalservice.api.util.InvalidNumberException; import org.whispersystems.signalservice.api.util.InvalidNumberException;
import org.whispersystems.signalservice.api.util.PhoneNumberFormatter; import org.whispersystems.signalservice.api.util.PhoneNumberFormatter;
@ -183,8 +184,8 @@ public class ManagerImpl implements Manager {
}); });
disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(serviceId -> { disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(serviceId -> {
logger.trace("Archiving old sessions for {}", serviceId); logger.trace("Archiving old sessions for {}", serviceId);
account.getAciSessionStore().archiveSessions(serviceId); account.getAccountData(ServiceIdType.ACI).getSessionStore().archiveSessions(serviceId);
account.getPniSessionStore().archiveSessions(serviceId); account.getAccountData(ServiceIdType.PNI).getSessionStore().archiveSessions(serviceId);
account.getSenderKeyStore().deleteSharedWith(serviceId); account.getSenderKeyStore().deleteSharedWith(serviceId);
final var recipientId = account.getRecipientResolver().resolveRecipient(serviceId); final var recipientId = account.getRecipientResolver().resolveRecipient(serviceId);
final var profile = account.getProfileStore().getProfile(recipientId); final var profile = account.getProfileStore().getProfile(recipientId);
@ -775,7 +776,7 @@ public class ManagerImpl implements Manager {
.resolveRecipientAddress(recipientId) .resolveRecipientAddress(recipientId)
.serviceId(); .serviceId();
if (serviceId.isPresent()) { if (serviceId.isPresent()) {
account.getAciSessionStore().deleteAllSessions(serviceId.get()); account.getAccountData(ServiceIdType.ACI).getSessionStore().deleteAllSessions(serviceId.get());
} }
} }
} }

View file

@ -89,7 +89,6 @@ import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel; import java.nio.channels.FileChannel;
import java.nio.channels.FileLock; import java.nio.channels.FileLock;
import java.nio.file.Files; import java.nio.file.Files;
import java.security.SecureRandom;
import java.sql.Connection; import java.sql.Connection;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.Base64; import java.util.Base64;
@ -136,36 +135,14 @@ public class SignalAccount implements Closeable {
private StorageKey storageKey; private StorageKey storageKey;
private long storageManifestVersion = -1; private long storageManifestVersion = -1;
private ProfileKey profileKey; private ProfileKey profileKey;
private int aciPreKeyIdOffset = 1;
private int aciNextSignedPreKeyId = 1;
private int pniPreKeyIdOffset = 1;
private int pniNextSignedPreKeyId = 1;
private int aciKyberPreKeyIdOffset = 1;
private int aciActiveLastResortKyberPreKeyId = -1;
private int pniKyberPreKeyIdOffset = 1;
private int pniActiveLastResortKyberPreKeyId = -1;
private IdentityKeyPair aciIdentityKeyPair;
private IdentityKeyPair pniIdentityKeyPair;
private int localRegistrationId;
private int localPniRegistrationId;
private Settings settings; private Settings settings;
private long lastReceiveTimestamp = 0; private long lastReceiveTimestamp = 0;
private boolean registered = false; private boolean registered = false;
private SignalProtocolStore aciSignalProtocolStore; private final AccountData aciAccountData = new AccountData(ServiceIdType.ACI);
private SignalProtocolStore pniSignalProtocolStore; private final AccountData pniAccountData = new AccountData(ServiceIdType.PNI);
private PreKeyStore aciPreKeyStore;
private SignedPreKeyStore aciSignedPreKeyStore;
private KyberPreKeyStore aciKyberPreKeyStore;
private PreKeyStore pniPreKeyStore;
private SignedPreKeyStore pniSignedPreKeyStore;
private KyberPreKeyStore pniKyberPreKeyStore;
private SessionStore aciSessionStore;
private SessionStore pniSessionStore;
private IdentityKeyStore identityKeyStore; private IdentityKeyStore identityKeyStore;
private SignalIdentityKeyStore aciIdentityKeyStore;
private SignalIdentityKeyStore pniIdentityKeyStore;
private SenderKeyStore senderKeyStore; private SenderKeyStore senderKeyStore;
private GroupStore groupStore; private GroupStore groupStore;
private RecipientStore recipientStore; private RecipientStore recipientStore;
@ -231,10 +208,10 @@ public class SignalAccount implements Closeable {
signalAccount.profileKey = profileKey; signalAccount.profileKey = profileKey;
signalAccount.dataPath = dataPath; signalAccount.dataPath = dataPath;
signalAccount.aciIdentityKeyPair = aciIdentityKey; signalAccount.aciAccountData.setIdentityKeyPair(aciIdentityKey);
signalAccount.pniIdentityKeyPair = pniIdentityKey; signalAccount.pniAccountData.setIdentityKeyPair(pniIdentityKey);
signalAccount.localRegistrationId = registrationId; signalAccount.aciAccountData.setLocalRegistrationId(registrationId);
signalAccount.localPniRegistrationId = pniRegistrationId; signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId);
signalAccount.settings = settings; signalAccount.settings = settings;
signalAccount.configurationStore = new ConfigurationStore(signalAccount::saveConfigurationStore); signalAccount.configurationStore = new ConfigurationStore(signalAccount::saveConfigurationStore);
@ -296,8 +273,8 @@ public class SignalAccount implements Closeable {
profileKey); profileKey);
signalAccount.getRecipientTrustedResolver() signalAccount.getRecipientTrustedResolver()
.resolveSelfRecipientTrusted(signalAccount.getSelfRecipientAddress()); .resolveSelfRecipientTrusted(signalAccount.getSelfRecipientAddress());
signalAccount.getAciSessionStore().archiveAllSessions(); signalAccount.aciAccountData.getSessionStore().archiveAllSessions();
signalAccount.getPniSessionStore().archiveAllSessions(); signalAccount.pniAccountData.getSessionStore().archiveAllSessions();
signalAccount.getSenderKeyStore().deleteAll(); signalAccount.getSenderKeyStore().deleteAll();
signalAccount.clearAllPreKeys(); signalAccount.clearAllPreKeys();
return signalAccount; return signalAccount;
@ -308,16 +285,17 @@ public class SignalAccount implements Closeable {
} }
private void clearAllPreKeys() { private void clearAllPreKeys() {
resetPreKeyOffsets(ServiceIdType.ACI); clearAllPreKeys(ServiceIdType.ACI);
resetPreKeyOffsets(ServiceIdType.PNI); clearAllPreKeys(ServiceIdType.PNI);
resetKyberPreKeyOffsets(ServiceIdType.ACI); }
resetKyberPreKeyOffsets(ServiceIdType.PNI);
this.getAciPreKeyStore().removeAllPreKeys(); private void clearAllPreKeys(ServiceIdType serviceIdType) {
this.getAciSignedPreKeyStore().removeAllSignedPreKeys(); final var accountData = getAccountData(serviceIdType);
this.getAciKyberPreKeyStore().removeAllKyberPreKeys(); resetPreKeyOffsets(serviceIdType);
this.getPniPreKeyStore().removeAllPreKeys(); resetKyberPreKeyOffsets(serviceIdType);
this.getPniSignedPreKeyStore().removeAllSignedPreKeys(); accountData.getPreKeyStore().removeAllPreKeys();
this.getPniKyberPreKeyStore().removeAllKyberPreKeys(); accountData.getSignedPreKeyStore().removeAllSignedPreKeys();
accountData.getKyberPreKeyStore().removeAllKyberPreKeys();
save(); save();
} }
@ -347,8 +325,8 @@ public class SignalAccount implements Closeable {
signalAccount.dataPath = dataPath; signalAccount.dataPath = dataPath;
signalAccount.accountPath = accountPath; signalAccount.accountPath = accountPath;
signalAccount.serviceEnvironment = serviceEnvironment; signalAccount.serviceEnvironment = serviceEnvironment;
signalAccount.localRegistrationId = registrationId; signalAccount.aciAccountData.setLocalRegistrationId(registrationId);
signalAccount.localPniRegistrationId = pniRegistrationId; signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId);
signalAccount.settings = settings; signalAccount.settings = settings;
signalAccount.setProvisioningData(number, signalAccount.setProvisioningData(number,
aci, aci,
@ -391,8 +369,8 @@ public class SignalAccount implements Closeable {
getProfileStore().storeSelfProfileKey(getSelfRecipientId(), getProfileKey()); getProfileStore().storeSelfProfileKey(getSelfRecipientId(), getProfileKey());
this.encryptedDeviceName = encryptedDeviceName; this.encryptedDeviceName = encryptedDeviceName;
this.deviceId = deviceId; this.deviceId = deviceId;
this.aciIdentityKeyPair = aciIdentity; this.aciAccountData.setIdentityKeyPair(aciIdentity);
this.pniIdentityKeyPair = pniIdentity; this.pniAccountData.setIdentityKeyPair(pniIdentity);
this.registered = true; this.registered = true;
this.isMultiDevice = true; this.isMultiDevice = true;
this.lastReceiveTimestamp = 0; this.lastReceiveTimestamp = 0;
@ -400,13 +378,9 @@ public class SignalAccount implements Closeable {
this.storageManifestVersion = -1; this.storageManifestVersion = -1;
this.setStorageManifest(null); this.setStorageManifest(null);
this.storageKey = null; this.storageKey = null;
final var aciPublicKey = getAciIdentityKeyPair().getPublicKey(); trustSelfIdentity(ServiceIdType.ACI);
getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
if (getPniIdentityKeyPair() != null) { if (getPniIdentityKeyPair() != null) {
final var pniPublicKey = getPniIdentityKeyPair().getPublicKey(); trustSelfIdentity(ServiceIdType.PNI);
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
} }
} }
@ -438,8 +412,8 @@ public class SignalAccount implements Closeable {
getMessageCache().deleteMessages(recipientId); getMessageCache().deleteMessages(recipientId);
if (recipientAddress.serviceId().isPresent()) { if (recipientAddress.serviceId().isPresent()) {
final var serviceId = recipientAddress.serviceId().get(); final var serviceId = recipientAddress.serviceId().get();
getAciSessionStore().deleteAllSessions(serviceId); aciAccountData.getSessionStore().deleteAllSessions(serviceId);
getPniSessionStore().deleteAllSessions(serviceId); pniAccountData.getSessionStore().deleteAllSessions(serviceId);
getIdentityKeyStore().deleteIdentity(serviceId); getIdentityKeyStore().deleteIdentity(serviceId);
getSenderKeyStore().deleteAll(serviceId); getSenderKeyStore().deleteAll(serviceId);
} }
@ -595,9 +569,9 @@ public class SignalAccount implements Closeable {
registrationId = rootNode.get("registrationId").asInt(); registrationId = rootNode.get("registrationId").asInt();
} }
if (rootNode.hasNonNull("pniRegistrationId")) { if (rootNode.hasNonNull("pniRegistrationId")) {
localPniRegistrationId = rootNode.get("pniRegistrationId").asInt(); pniAccountData.setLocalRegistrationId(rootNode.get("pniRegistrationId").asInt());
} else { } else {
localPniRegistrationId = KeyHelper.generateRegistrationId(false); pniAccountData.setLocalRegistrationId(KeyHelper.generateRegistrationId(false));
} }
IdentityKeyPair aciIdentityKeyPair = null; IdentityKeyPair aciIdentityKeyPair = null;
if (rootNode.hasNonNull("identityPrivateKey") && rootNode.hasNonNull("identityKey")) { if (rootNode.hasNonNull("identityPrivateKey") && rootNode.hasNonNull("identityKey")) {
@ -608,7 +582,7 @@ public class SignalAccount implements Closeable {
if (rootNode.hasNonNull("pniIdentityPrivateKey") && rootNode.hasNonNull("pniIdentityKey")) { if (rootNode.hasNonNull("pniIdentityPrivateKey") && rootNode.hasNonNull("pniIdentityKey")) {
final var publicKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityKey").asText()); final var publicKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityKey").asText());
final var privateKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityPrivateKey").asText()); final var privateKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityPrivateKey").asText());
pniIdentityKeyPair = KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes); pniAccountData.setIdentityKeyPair(KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes));
} }
if (rootNode.hasNonNull("registrationLockPin")) { if (rootNode.hasNonNull("registrationLockPin")) {
@ -624,44 +598,46 @@ public class SignalAccount implements Closeable {
storageManifestVersion = rootNode.get("storageManifestVersion").asLong(); storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
} }
if (rootNode.hasNonNull("preKeyIdOffset")) { if (rootNode.hasNonNull("preKeyIdOffset")) {
aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); aciAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
} else { } else {
aciPreKeyIdOffset = getRandomPreKeyIdOffset(); aciAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("nextSignedPreKeyId")) { if (rootNode.hasNonNull("nextSignedPreKeyId")) {
aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); aciAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
} else { } else {
aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); aciAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("pniPreKeyIdOffset")) { if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); pniAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
} else { } else {
pniPreKeyIdOffset = getRandomPreKeyIdOffset(); pniAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("pniNextSignedPreKeyId")) { if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); pniAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
} else { } else {
pniNextSignedPreKeyId = getRandomPreKeyIdOffset(); pniAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("kyberPreKeyIdOffset")) { if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1); aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
} else { } else {
aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) { if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) {
aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1); aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId")
.asInt(-1);
} else { } else {
aciActiveLastResortKyberPreKeyId = -1; aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1;
} }
if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) { if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) {
pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1); pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
} else { } else {
pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
} }
if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) { if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) {
pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1); pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get(
"pniActiveLastResortKyberPreKeyId").asInt(-1);
} else { } else {
pniActiveLastResortKyberPreKeyId = -1; pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1;
} }
if (rootNode.hasNonNull("profileKey")) { if (rootNode.hasNonNull("profileKey")) {
try { try {
@ -687,22 +663,22 @@ public class SignalAccount implements Closeable {
} }
final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath); final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath);
if (legacyAciPreKeysPath.exists()) { if (legacyAciPreKeysPath.exists()) {
LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore()); LegacyPreKeyStore.migrate(legacyAciPreKeysPath, aciAccountData.getPreKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacyPniPreKeysPath = getPniPreKeysPath(dataPath, accountPath); final var legacyPniPreKeysPath = getPniPreKeysPath(dataPath, accountPath);
if (legacyPniPreKeysPath.exists()) { if (legacyPniPreKeysPath.exists()) {
LegacyPreKeyStore.migrate(legacyPniPreKeysPath, getPniPreKeyStore()); LegacyPreKeyStore.migrate(legacyPniPreKeysPath, pniAccountData.getPreKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacyAciSignedPreKeysPath = getAciSignedPreKeysPath(dataPath, accountPath); final var legacyAciSignedPreKeysPath = getAciSignedPreKeysPath(dataPath, accountPath);
if (legacyAciSignedPreKeysPath.exists()) { if (legacyAciSignedPreKeysPath.exists()) {
LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, getAciSignedPreKeyStore()); LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, aciAccountData.getSignedPreKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacyPniSignedPreKeysPath = getPniSignedPreKeysPath(dataPath, accountPath); final var legacyPniSignedPreKeysPath = getPniSignedPreKeysPath(dataPath, accountPath);
if (legacyPniSignedPreKeysPath.exists()) { if (legacyPniSignedPreKeysPath.exists()) {
LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore()); LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, pniAccountData.getSignedPreKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacySessionsPath = getSessionsPath(dataPath, accountPath); final var legacySessionsPath = getSessionsPath(dataPath, accountPath);
@ -710,7 +686,7 @@ public class SignalAccount implements Closeable {
LegacySessionStore.migrate(legacySessionsPath, LegacySessionStore.migrate(legacySessionsPath,
getRecipientResolver(), getRecipientResolver(),
getRecipientAddressResolver(), getRecipientAddressResolver(),
getAciSessionStore()); aciAccountData.getSessionStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath); final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath);
@ -731,8 +707,8 @@ public class SignalAccount implements Closeable {
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
this.aciIdentityKeyPair = aciIdentityKeyPair; this.aciAccountData.setIdentityKeyPair(aciIdentityKeyPair);
this.localRegistrationId = registrationId; this.aciAccountData.setLocalRegistrationId(registrationId);
migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig; migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig;
@ -805,7 +781,7 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy pre key store."); logger.debug("Migrating legacy pre key store.");
for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) { for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
try { try {
getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); aciAccountData.getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
logger.warn("Failed to migrate pre key, ignoring", e); logger.warn("Failed to migrate pre key, ignoring", e);
} }
@ -817,8 +793,8 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy signed pre key store."); logger.debug("Migrating legacy signed pre key store.");
for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) { for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
try { try {
getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(), aciAccountData.getSignedPreKeyStore()
new SignedPreKeyRecord(entry.getValue())); .storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
logger.warn("Failed to migrate signed pre key, ignoring", e); logger.warn("Failed to migrate signed pre key, ignoring", e);
} }
@ -830,8 +806,9 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy session store."); logger.debug("Migrating legacy session store.");
for (var session : legacySignalProtocolStore.getLegacySessionStore().getSessions()) { for (var session : legacySignalProtocolStore.getLegacySessionStore().getSessions()) {
try { try {
getAciSessionStore().storeSession(new SignalProtocolAddress(session.address.getIdentifier(), aciAccountData.getSessionStore()
session.deviceId), new SessionRecord(session.sessionRecord)); .storeSession(new SignalProtocolAddress(session.address.getIdentifier(), session.deviceId),
new SessionRecord(session.sessionRecord));
} catch (Exception e) { } catch (Exception e) {
logger.warn("Failed to migrate session, ignoring", e); logger.warn("Failed to migrate session, ignoring", e);
} }
@ -981,35 +958,44 @@ public class SignalAccount implements Closeable {
.put("isMultiDevice", isMultiDevice) .put("isMultiDevice", isMultiDevice)
.put("lastReceiveTimestamp", lastReceiveTimestamp) .put("lastReceiveTimestamp", lastReceiveTimestamp)
.put("password", password) .put("password", password)
.put("registrationId", localRegistrationId) .put("registrationId", aciAccountData.getLocalRegistrationId())
.put("pniRegistrationId", localPniRegistrationId) .put("pniRegistrationId", pniAccountData.getLocalRegistrationId())
.put("identityPrivateKey", .put("identityPrivateKey",
Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPrivateKey().serialize())) Base64.getEncoder()
.encodeToString(aciAccountData.getIdentityKeyPair().getPrivateKey().serialize()))
.put("identityKey", .put("identityKey",
Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPublicKey().serialize())) Base64.getEncoder()
.encodeToString(aciAccountData.getIdentityKeyPair().getPublicKey().serialize()))
.put("pniIdentityPrivateKey", .put("pniIdentityPrivateKey",
pniIdentityKeyPair == null pniAccountData.getIdentityKeyPair() == null
? null ? null
: Base64.getEncoder() : Base64.getEncoder()
.encodeToString(pniIdentityKeyPair.getPrivateKey().serialize())) .encodeToString(pniAccountData.getIdentityKeyPair()
.getPrivateKey()
.serialize()))
.put("pniIdentityKey", .put("pniIdentityKey",
pniIdentityKeyPair == null pniAccountData.getIdentityKeyPair() == null
? null ? null
: Base64.getEncoder().encodeToString(pniIdentityKeyPair.getPublicKey().serialize())) : Base64.getEncoder()
.encodeToString(pniAccountData.getIdentityKeyPair()
.getPublicKey()
.serialize()))
.put("registrationLockPin", registrationLockPin) .put("registrationLockPin", registrationLockPin)
.put("pinMasterKey", .put("pinMasterKey",
pinMasterKey == null ? null : Base64.getEncoder().encodeToString(pinMasterKey.serialize())) pinMasterKey == null ? null : Base64.getEncoder().encodeToString(pinMasterKey.serialize()))
.put("storageKey", .put("storageKey",
storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize())) storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
.put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion) .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
.put("preKeyIdOffset", aciPreKeyIdOffset) .put("preKeyIdOffset", aciAccountData.getPreKeyMetadata().preKeyIdOffset)
.put("nextSignedPreKeyId", aciNextSignedPreKeyId) .put("nextSignedPreKeyId", aciAccountData.getPreKeyMetadata().nextSignedPreKeyId)
.put("pniPreKeyIdOffset", pniPreKeyIdOffset) .put("pniPreKeyIdOffset", pniAccountData.getPreKeyMetadata().preKeyIdOffset)
.put("pniNextSignedPreKeyId", pniNextSignedPreKeyId) .put("pniNextSignedPreKeyId", pniAccountData.getPreKeyMetadata().nextSignedPreKeyId)
.put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset) .put("kyberPreKeyIdOffset", aciAccountData.getPreKeyMetadata().kyberPreKeyIdOffset)
.put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId) .put("activeLastResortKyberPreKeyId",
.put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset) aciAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId)
.put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId) .put("pniKyberPreKeyIdOffset", pniAccountData.getPreKeyMetadata().kyberPreKeyIdOffset)
.put("pniActiveLastResortKyberPreKeyId",
pniAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId)
.put("profileKey", .put("profileKey",
profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
.put("registered", registered) .put("registered", registered)
@ -1054,154 +1040,79 @@ public class SignalAccount implements Closeable {
} }
public void resetPreKeyOffsets(final ServiceIdType serviceIdType) { public void resetPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata();
this.aciPreKeyIdOffset = getRandomPreKeyIdOffset(); preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
} else {
this.pniPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
}
save(); save();
} }
private static int getRandomPreKeyIdOffset() { private static int getRandomPreKeyIdOffset() {
return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID); return KeyUtils.getRandomInt(PREKEY_MAXIMUM_ID);
} }
public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) { public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var accountData = getAccountData(serviceIdType);
addAciPreKeys(records); final var preKeyMetadata = accountData.getPreKeyMetadata();
} else {
addPniPreKeys(records);
}
}
private void addAciPreKeys(List<PreKeyRecord> records) {
for (var record : records) { for (var record : records) {
if (aciPreKeyIdOffset != record.getId()) { if (preKeyMetadata.preKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset); logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyMetadata.preKeyIdOffset);
throw new AssertionError("Invalid pre key id"); throw new AssertionError("Invalid pre key id");
} }
getAciPreKeyStore().storePreKey(record.getId(), record); accountData.getPreKeyStore().storePreKey(record.getId(), record);
aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; preKeyMetadata.preKeyIdOffset = (preKeyMetadata.preKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
private void addPniPreKeys(List<PreKeyRecord> records) {
for (var record : records) {
if (pniPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getPniPreKeyStore().storePreKey(record.getId(), record);
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
} }
save(); save();
} }
public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) { public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var accountData = getAccountData(serviceIdType);
addAciSignedPreKey(record); final var preKeyMetadata = accountData.getPreKeyMetadata();
} else { if (preKeyMetadata.nextSignedPreKeyId != record.getId()) {
addPniSignedPreKey(record); logger.error("Invalid signed pre key id {}, expected {}",
} record.getId(),
} preKeyMetadata.nextSignedPreKeyId);
public void addAciSignedPreKey(SignedPreKeyRecord record) {
if (aciNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id"); throw new AssertionError("Invalid signed pre key id");
} }
getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record); accountData.getSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; preKeyMetadata.nextSignedPreKeyId = (preKeyMetadata.nextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void addPniSignedPreKey(SignedPreKeyRecord record) {
if (pniNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id");
}
getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
save(); save();
} }
public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) { public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata();
this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.aciActiveLastResortKyberPreKeyId = -1; preKeyMetadata.activeLastResortKyberPreKeyId = -1;
} else {
this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
this.pniActiveLastResortKyberPreKeyId = -1;
}
save(); save();
} }
public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) { public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var accountData = getAccountData(serviceIdType);
addAciKyberPreKeys(records); final var preKeyMetadata = accountData.getPreKeyMetadata();
} else {
addPniKyberPreKeys(records);
}
}
private void addAciKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) { for (var record : records) {
if (aciKyberPreKeyIdOffset != record.getId()) { if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset); logger.error("Invalid kyber pre key id {}, expected {}",
record.getId(),
preKeyMetadata.kyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id"); throw new AssertionError("Invalid kyber pre key id");
} }
getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record); accountData.getKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
}
save();
}
private void addPniKyberPreKeys(List<KyberPreKeyRecord> records) {
for (var record : records) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid kyber pre key id");
}
getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
} }
save(); save();
} }
public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) { public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) { final var accountData = getAccountData(serviceIdType);
addAciLastResortKyberPreKey(record); final var preKeyMetadata = accountData.getPreKeyMetadata();
} else { if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) {
addPniLastResortKyberPreKey(record);
}
}
public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) {
if (aciKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}", logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(), record.getId(),
aciKyberPreKeyIdOffset); preKeyMetadata.kyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id"); throw new AssertionError("Invalid last resort kyber pre key id");
} }
getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); accountData.getKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
aciActiveLastResortKyberPreKeyId = record.getId(); preKeyMetadata.activeLastResortKyberPreKeyId = record.getId();
aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save();
}
public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) {
if (pniKyberPreKeyIdOffset != record.getId()) {
logger.error("Invalid last resort kyber pre key id {}, expected {}",
record.getId(),
pniKyberPreKeyIdOffset);
throw new AssertionError("Invalid last resort kyber pre key id");
}
getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
pniActiveLastResortKyberPreKeyId = record.getId();
pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
save(); save();
} }
@ -1209,6 +1120,23 @@ public class SignalAccount implements Closeable {
return previousStorageVersion; return previousStorageVersion;
} }
public AccountData getAccountData(ServiceIdType serviceIdType) {
return switch (serviceIdType) {
case ACI -> aciAccountData;
case PNI -> pniAccountData;
};
}
public AccountData getAccountData(ServiceId accountIdentifier) {
if (accountIdentifier.equals(aci)) {
return aciAccountData;
} else if (accountIdentifier.equals(pni)) {
return pniAccountData;
} else {
throw new IllegalArgumentException("No matching account data found for " + accountIdentifier);
}
}
public SignalServiceDataStore getSignalServiceDataStore() { public SignalServiceDataStore getSignalServiceDataStore() {
return new SignalServiceDataStore() { return new SignalServiceDataStore() {
@Override @Override
@ -1224,12 +1152,12 @@ public class SignalAccount implements Closeable {
@Override @Override
public SignalServiceAccountDataStore aci() { public SignalServiceAccountDataStore aci() {
return getAciSignalServiceAccountDataStore(); return aciAccountData.getSignalServiceAccountDataStore();
} }
@Override @Override
public SignalServiceAccountDataStore pni() { public SignalServiceAccountDataStore pni() {
return getPniSignalServiceAccountDataStore(); return pniAccountData.getSignalServiceAccountDataStore();
} }
@Override @Override
@ -1239,89 +1167,11 @@ public class SignalAccount implements Closeable {
}; };
} }
private SignalServiceAccountDataStore getAciSignalServiceAccountDataStore() {
return getOrCreate(() -> aciSignalProtocolStore,
() -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getAciSignedPreKeyStore(),
getAciKyberPreKeyStore(),
getAciSessionStore(),
getAciIdentityKeyStore(),
getSenderKeyStore(),
this::isMultiDevice));
}
private SignalServiceAccountDataStore getPniSignalServiceAccountDataStore() {
return getOrCreate(() -> pniSignalProtocolStore,
() -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(),
getPniSignedPreKeyStore(),
getPniKyberPreKeyStore(),
getPniSessionStore(),
getPniIdentityKeyStore(),
getSenderKeyStore(),
this::isMultiDevice));
}
private PreKeyStore getAciPreKeyStore() {
return getOrCreate(() -> aciPreKeyStore,
() -> aciPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private SignedPreKeyStore getAciSignedPreKeyStore() {
return getOrCreate(() -> aciSignedPreKeyStore,
() -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private KyberPreKeyStore getAciKyberPreKeyStore() {
return getOrCreate(() -> aciKyberPreKeyStore,
() -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
}
private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
private SignedPreKeyStore getPniSignedPreKeyStore() {
return getOrCreate(() -> pniSignedPreKeyStore,
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
private KyberPreKeyStore getPniKyberPreKeyStore() {
return getOrCreate(() -> pniKyberPreKeyStore,
() -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
}
public SessionStore getAciSessionStore() {
return getOrCreate(() -> aciSessionStore,
() -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
}
public SessionStore getPniSessionStore() {
return getOrCreate(() -> pniSessionStore,
() -> pniSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.PNI));
}
public IdentityKeyStore getIdentityKeyStore() { public IdentityKeyStore getIdentityKeyStore() {
return getOrCreate(() -> identityKeyStore, return getOrCreate(() -> identityKeyStore,
() -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), settings.trustNewIdentity())); () -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), settings.trustNewIdentity()));
} }
public SignalIdentityKeyStore getAciIdentityKeyStore() {
return getOrCreate(() -> aciIdentityKeyStore,
() -> aciIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
() -> aciIdentityKeyPair,
localRegistrationId,
getIdentityKeyStore()));
}
public SignalIdentityKeyStore getPniIdentityKeyStore() {
return getOrCreate(() -> pniIdentityKeyStore,
() -> pniIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
() -> pniIdentityKeyPair,
localRegistrationId,
getIdentityKeyStore()));
}
public GroupStore getGroupStore() { public GroupStore getGroupStore() {
return getOrCreate(() -> groupStore, return getOrCreate(() -> groupStore,
() -> groupStore = new GroupStore(getAccountDatabase(), () -> groupStore = new GroupStore(getAccountDatabase(),
@ -1489,11 +1339,7 @@ public class SignalAccount implements Closeable {
if (this.pni != null && !this.pni.equals(updatedPni)) { if (this.pni != null && !this.pni.equals(updatedPni)) {
// Clear data for old PNI // Clear data for old PNI
identityKeyStore.deleteIdentity(this.pni); identityKeyStore.deleteIdentity(this.pni);
getPniPreKeyStore().removeAllPreKeys(); clearAllPreKeys(ServiceIdType.PNI);
getPniSignedPreKeyStore().removeAllSignedPreKeys();
getPniKyberPreKeyStore().removeAllKyberPreKeys();
aciActiveLastResortKyberPreKeyId = -1;
pniActiveLastResortKyberPreKeyId = -1;
} }
this.pni = updatedPni; this.pni = updatedPni;
@ -1509,7 +1355,7 @@ public class SignalAccount implements Closeable {
setPni(updatedPni); setPni(updatedPni);
setPniIdentityKeyPair(pniIdentityKeyPair); setPniIdentityKeyPair(pniIdentityKeyPair);
addPniSignedPreKey(pniSignedPreKey); addSignedPreKey(ServiceIdType.PNI, pniSignedPreKey);
setLocalPniRegistrationId(localPniRegistrationId); setLocalPniRegistrationId(localPniRegistrationId);
} }
@ -1552,35 +1398,33 @@ public class SignalAccount implements Closeable {
} }
public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) { public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair; return getAccountData(serviceIdType).getIdentityKeyPair();
} }
public IdentityKeyPair getAciIdentityKeyPair() { public IdentityKeyPair getAciIdentityKeyPair() {
return aciIdentityKeyPair; return aciAccountData.getIdentityKeyPair();
} }
public IdentityKeyPair getPniIdentityKeyPair() { public IdentityKeyPair getPniIdentityKeyPair() {
return pniIdentityKeyPair; return pniAccountData.getIdentityKeyPair();
} }
public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) { public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) {
pniIdentityKeyPair = identityKeyPair; pniAccountData.setIdentityKeyPair(identityKeyPair);
final var pniPublicKey = identityKeyPair.getPublicKey(); trustSelfIdentity(ServiceIdType.PNI);
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
save(); save();
} }
public int getLocalRegistrationId() { public int getLocalRegistrationId() {
return localRegistrationId; return aciAccountData.getLocalRegistrationId();
} }
public int getLocalPniRegistrationId() { public int getLocalPniRegistrationId() {
return localPniRegistrationId; return pniAccountData.getLocalRegistrationId();
} }
public void setLocalPniRegistrationId(final int localPniRegistrationId) { public void setLocalPniRegistrationId(final int localPniRegistrationId) {
this.localPniRegistrationId = localPniRegistrationId; pniAccountData.setLocalRegistrationId(localPniRegistrationId);
save(); save();
} }
@ -1711,15 +1555,15 @@ public class SignalAccount implements Closeable {
} }
public int getPreKeyIdOffset(ServiceIdType serviceIdType) { public int getPreKeyIdOffset(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset; return getAccountData(serviceIdType).getPreKeyMetadata().preKeyIdOffset;
} }
public int getNextSignedPreKeyId(ServiceIdType serviceIdType) { public int getNextSignedPreKeyId(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId; return getAccountData(serviceIdType).getPreKeyMetadata().nextSignedPreKeyId;
} }
public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) { public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset; return getAccountData(serviceIdType).getPreKeyMetadata().kyberPreKeyIdOffset;
} }
public boolean isRegistered() { public boolean isRegistered() {
@ -1778,22 +1622,26 @@ public class SignalAccount implements Closeable {
save(); save();
clearAllPreKeys(); clearAllPreKeys();
getAciSessionStore().archiveAllSessions(); aciAccountData.getSessionStore().archiveAllSessions();
getPniSessionStore().archiveAllSessions(); pniAccountData.getSessionStore().archiveAllSessions();
getSenderKeyStore().deleteAll(); getSenderKeyStore().deleteAll();
getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress()); getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress());
final var aciPublicKey = getAciIdentityKeyPair().getPublicKey(); trustSelfIdentity(ServiceIdType.ACI);
getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
if (getPniIdentityKeyPair() == null) { if (getPniIdentityKeyPair() == null) {
setPniIdentityKeyPair(KeyUtils.generateIdentityKeyPair()); setPniIdentityKeyPair(KeyUtils.generateIdentityKeyPair());
} else { } else {
final var pniPublicKey = getPniIdentityKeyPair().getPublicKey(); trustSelfIdentity(ServiceIdType.PNI);
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
} }
} }
private void trustSelfIdentity(ServiceIdType serviceIdType) {
final var accountData = getAccountData(serviceIdType);
final var serviceId = accountData.getServiceId();
final var publicKey = accountData.getIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(serviceId, publicKey);
getIdentityKeyStore().setIdentityTrustLevel(serviceId, publicKey, TrustLevel.TRUSTED_VERIFIED);
}
public void deleteAccountData() throws IOException { public void deleteAccountData() throws IOException {
close(); close();
try (final var files = Files.walk(getUserPath(dataPath, accountPath).toPath()) try (final var files = Files.walk(getUserPath(dataPath, accountPath).toPath())
@ -1850,4 +1698,95 @@ public class SignalAccount implements Closeable {
void call(); void call();
} }
private static class PreKeyMetadata {
private int preKeyIdOffset = 1;
private int nextSignedPreKeyId = 1;
private int kyberPreKeyIdOffset = 1;
private int activeLastResortKyberPreKeyId = -1;
}
public class AccountData {
private final ServiceIdType serviceIdType;
private IdentityKeyPair identityKeyPair;
private int localRegistrationId;
private final PreKeyMetadata preKeyMetadata = new PreKeyMetadata();
private SignalProtocolStore signalProtocolStore;
private PreKeyStore preKeyStore;
private SignedPreKeyStore signedPreKeyStore;
private KyberPreKeyStore kyberPreKeyStore;
private SessionStore sessionStore;
private SignalIdentityKeyStore identityKeyStore;
public AccountData(final ServiceIdType serviceIdType) {
this.serviceIdType = serviceIdType;
}
public ServiceId getServiceId() {
return getAccountId(serviceIdType);
}
public IdentityKeyPair getIdentityKeyPair() {
return identityKeyPair;
}
private void setIdentityKeyPair(final IdentityKeyPair identityKeyPair) {
this.identityKeyPair = identityKeyPair;
}
public int getLocalRegistrationId() {
return localRegistrationId;
}
private void setLocalRegistrationId(final int localRegistrationId) {
this.localRegistrationId = localRegistrationId;
this.identityKeyStore = null;
}
public PreKeyMetadata getPreKeyMetadata() {
return preKeyMetadata;
}
private SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
return getOrCreate(() -> signalProtocolStore,
() -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(),
getSignedPreKeyStore(),
getKyberPreKeyStore(),
getSessionStore(),
getIdentityKeyStore(),
getSenderKeyStore(),
SignalAccount.this::isMultiDevice));
}
private PreKeyStore getPreKeyStore() {
return getOrCreate(() -> preKeyStore,
() -> preKeyStore = new PreKeyStore(getAccountDatabase(), serviceIdType));
}
private SignedPreKeyStore getSignedPreKeyStore() {
return getOrCreate(() -> signedPreKeyStore,
() -> signedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), serviceIdType));
}
private KyberPreKeyStore getKyberPreKeyStore() {
return getOrCreate(() -> kyberPreKeyStore,
() -> kyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), serviceIdType));
}
public SessionStore getSessionStore() {
return getOrCreate(() -> sessionStore,
() -> sessionStore = new SessionStore(getAccountDatabase(), serviceIdType));
}
public SignalIdentityKeyStore getIdentityKeyStore() {
return getOrCreate(() -> identityKeyStore,
() -> identityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
() -> identityKeyPair,
localRegistrationId,
SignalAccount.this.getIdentityKeyStore()));
}
}
} }

View file

@ -120,4 +120,8 @@ public class KeyUtils {
secureRandom.nextBytes(secret); secureRandom.nextBytes(secret);
return secret; return secret;
} }
public static int getRandomInt(int bound) {
return secureRandom.nextInt(bound);
}
} }