Key tables on serviceId instead of recipientId

This commit is contained in:
AsamK 2022-08-23 14:17:14 +02:00
parent 04fa046815
commit 280d8d7f10
25 changed files with 528 additions and 514 deletions

View file

@ -69,6 +69,7 @@ import org.whispersystems.signalservice.api.messages.SignalServicePreview;
import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage; import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage; import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage;
import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.util.DeviceNameUtil; import org.whispersystems.signalservice.api.util.DeviceNameUtil;
import org.whispersystems.signalservice.api.util.InvalidNumberException; import org.whispersystems.signalservice.api.util.InvalidNumberException;
import org.whispersystems.signalservice.api.util.PhoneNumberFormatter; import org.whispersystems.signalservice.api.util.PhoneNumberFormatter;
@ -165,11 +166,12 @@ class ManagerImpl implements Manager {
this.notifyAll(); this.notifyAll();
} }
}); });
disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(recipientId -> { disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(serviceId -> {
logger.trace("Archiving old sessions for {}", recipientId); logger.trace("Archiving old sessions for {}", serviceId);
account.getAciSessionStore().archiveSessions(recipientId); account.getAciSessionStore().archiveSessions(serviceId);
account.getPniSessionStore().archiveSessions(recipientId); account.getPniSessionStore().archiveSessions(serviceId);
account.getSenderKeyStore().deleteSharedWith(recipientId); account.getSenderKeyStore().deleteSharedWith(serviceId);
final var recipientId = account.getRecipientResolver().resolveRecipient(serviceId);
final var profile = account.getProfileStore().getProfile(recipientId); final var profile = account.getProfileStore().getProfile(recipientId);
if (profile != null) { if (profile != null) {
account.getProfileStore() account.getProfileStore()
@ -631,7 +633,11 @@ class ManagerImpl implements Manager {
if (recipient instanceof RecipientIdentifier.Single r) { if (recipient instanceof RecipientIdentifier.Single r) {
try { try {
final var recipientId = context.getRecipientHelper().resolveRecipient(r); final var recipientId = context.getRecipientHelper().resolveRecipient(r);
account.getMessageSendLogStore().deleteEntryForRecipientNonGroup(targetSentTimestamp, recipientId); account.getMessageSendLogStore()
.deleteEntryForRecipientNonGroup(targetSentTimestamp,
account.getRecipientAddressResolver()
.resolveRecipientAddress(recipientId)
.getServiceId());
} catch (UnregisteredRecipientException ignored) { } catch (UnregisteredRecipientException ignored) {
} }
} else if (recipient instanceof RecipientIdentifier.Group r) { } else if (recipient instanceof RecipientIdentifier.Group r) {
@ -689,7 +695,11 @@ class ManagerImpl implements Manager {
} catch (UnregisteredRecipientException e) { } catch (UnregisteredRecipientException e) {
continue; continue;
} }
account.getAciSessionStore().deleteAllSessions(recipientId); final var serviceId = context.getAccount()
.getRecipientAddressResolver()
.resolveRecipientAddress(recipientId)
.getServiceId();
account.getAciSessionStore().deleteAllSessions(serviceId);
} }
} }
} }
@ -1035,13 +1045,13 @@ class ManagerImpl implements Manager {
} }
final var address = account.getRecipientAddressResolver() final var address = account.getRecipientAddressResolver()
.resolveRecipientAddress(identityInfo.getRecipientId()); .resolveRecipientAddress(account.getRecipientResolver().resolveRecipient(identityInfo.getServiceId()));
final var scannableFingerprint = context.getIdentityHelper() final var scannableFingerprint = context.getIdentityHelper()
.computeSafetyNumberForScanning(identityInfo.getRecipientId(), identityInfo.getIdentityKey()); .computeSafetyNumberForScanning(identityInfo.getServiceId(), identityInfo.getIdentityKey());
return new Identity(address, return new Identity(address,
identityInfo.getIdentityKey(), identityInfo.getIdentityKey(),
context.getIdentityHelper() context.getIdentityHelper()
.computeSafetyNumber(identityInfo.getRecipientId(), identityInfo.getIdentityKey()), .computeSafetyNumber(identityInfo.getServiceId(), identityInfo.getIdentityKey()),
scannableFingerprint == null ? null : scannableFingerprint.getSerialized(), scannableFingerprint == null ? null : scannableFingerprint.getSerialized(),
identityInfo.getTrustLevel(), identityInfo.getTrustLevel(),
identityInfo.getDateAddedTimestamp()); identityInfo.getDateAddedTimestamp());
@ -1049,13 +1059,15 @@ class ManagerImpl implements Manager {
@Override @Override
public List<Identity> getIdentities(RecipientIdentifier.Single recipient) { public List<Identity> getIdentities(RecipientIdentifier.Single recipient) {
IdentityInfo identity; ServiceId serviceId;
try { try {
identity = account.getIdentityKeyStore() serviceId = account.getRecipientAddressResolver()
.getIdentityInfo(context.getRecipientHelper().resolveRecipient(recipient)); .resolveRecipientAddress(context.getRecipientHelper().resolveRecipient(recipient))
.getServiceId();
} catch (UnregisteredRecipientException e) { } catch (UnregisteredRecipientException e) {
identity = null; return List.of();
} }
final var identity = account.getIdentityKeyStore().getIdentityInfo(serviceId);
return identity == null ? List.of() : List.of(toIdentity(identity)); return identity == null ? List.of() : List.of(toIdentity(identity));
} }

View file

@ -2,18 +2,21 @@ package org.asamk.signal.manager.actions;
import org.asamk.signal.manager.helper.Context; import org.asamk.signal.manager.helper.Context;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.whispersystems.signalservice.api.push.ServiceId;
public class RenewSessionAction implements HandleAction { public class RenewSessionAction implements HandleAction {
private final RecipientId recipientId; private final RecipientId recipientId;
private final ServiceId serviceId;
public RenewSessionAction(final RecipientId recipientId) { public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId) {
this.recipientId = recipientId; this.recipientId = recipientId;
this.serviceId = serviceId;
} }
@Override @Override
public void execute(Context context) throws Throwable { public void execute(Context context) throws Throwable {
context.getAccount().getAciSessionStore().archiveSessions(recipientId); context.getAccount().getAciSessionStore().archiveSessions(serviceId);
if (!recipientId.equals(context.getAccount().getSelfRecipientId())) { if (!recipientId.equals(context.getAccount().getSelfRecipientId())) {
context.getSendHelper().sendNullMessage(recipientId); context.getSendHelper().sendNullMessage(recipientId);
} }

View file

@ -7,6 +7,7 @@ import org.signal.libsignal.metadata.ProtocolException;
import org.signal.libsignal.protocol.message.CiphertextMessage; import org.signal.libsignal.protocol.message.CiphertextMessage;
import org.signal.libsignal.protocol.message.DecryptionErrorMessage; import org.signal.libsignal.protocol.message.DecryptionErrorMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope; import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.internal.push.SignalServiceProtos; import org.whispersystems.signalservice.internal.push.SignalServiceProtos;
import java.util.Optional; import java.util.Optional;
@ -14,22 +15,25 @@ import java.util.Optional;
public class SendRetryMessageRequestAction implements HandleAction { public class SendRetryMessageRequestAction implements HandleAction {
private final RecipientId recipientId; private final RecipientId recipientId;
private final ServiceId serviceId;
private final ProtocolException protocolException; private final ProtocolException protocolException;
private final SignalServiceEnvelope envelope; private final SignalServiceEnvelope envelope;
public SendRetryMessageRequestAction( public SendRetryMessageRequestAction(
final RecipientId recipientId, final RecipientId recipientId,
final ServiceId serviceId,
final ProtocolException protocolException, final ProtocolException protocolException,
final SignalServiceEnvelope envelope final SignalServiceEnvelope envelope
) { ) {
this.recipientId = recipientId; this.recipientId = recipientId;
this.serviceId = serviceId;
this.protocolException = protocolException; this.protocolException = protocolException;
this.envelope = envelope; this.envelope = envelope;
} }
@Override @Override
public void execute(Context context) throws Throwable { public void execute(Context context) throws Throwable {
context.getAccount().getAciSessionStore().archiveSessions(recipientId); context.getAccount().getAciSessionStore().archiveSessions(serviceId);
int senderDevice = protocolException.getSenderDevice(); int senderDevice = protocolException.getSenderDevice();
Optional<GroupId> groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion( Optional<GroupId> groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion(

View file

@ -2,6 +2,7 @@ package org.asamk.signal.manager.helper;
import org.asamk.signal.manager.api.TrustLevel; import org.asamk.signal.manager.api.TrustLevel;
import org.asamk.signal.manager.storage.SignalAccount; import org.asamk.signal.manager.storage.SignalAccount;
import org.asamk.signal.manager.storage.recipients.RecipientAddress;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.util.Utils; import org.asamk.signal.manager.util.Utils;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
@ -12,7 +13,7 @@ import org.signal.libsignal.protocol.fingerprint.ScannableFingerprint;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.messages.SendMessageResult; import org.whispersystems.signalservice.api.messages.SendMessageResult;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.ServiceId;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@ -33,20 +34,23 @@ public class IdentityHelper {
} }
public boolean trustIdentityVerified(RecipientId recipientId, byte[] fingerprint) { public boolean trustIdentityVerified(RecipientId recipientId, byte[] fingerprint) {
return trustIdentity(recipientId, final var serviceId = account.getRecipientAddressResolver().resolveRecipientAddress(recipientId).getServiceId();
return trustIdentity(serviceId,
identityKey -> Arrays.equals(identityKey.serialize(), fingerprint), identityKey -> Arrays.equals(identityKey.serialize(), fingerprint),
TrustLevel.TRUSTED_VERIFIED); TrustLevel.TRUSTED_VERIFIED);
} }
public boolean trustIdentityVerifiedSafetyNumber(RecipientId recipientId, String safetyNumber) { public boolean trustIdentityVerifiedSafetyNumber(RecipientId recipientId, String safetyNumber) {
return trustIdentity(recipientId, final var serviceId = account.getRecipientAddressResolver().resolveRecipientAddress(recipientId).getServiceId();
identityKey -> safetyNumber.equals(computeSafetyNumber(recipientId, identityKey)), return trustIdentity(serviceId,
identityKey -> safetyNumber.equals(computeSafetyNumber(serviceId, identityKey)),
TrustLevel.TRUSTED_VERIFIED); TrustLevel.TRUSTED_VERIFIED);
} }
public boolean trustIdentityVerifiedSafetyNumber(RecipientId recipientId, byte[] safetyNumber) { public boolean trustIdentityVerifiedSafetyNumber(RecipientId recipientId, byte[] safetyNumber) {
return trustIdentity(recipientId, identityKey -> { final var serviceId = account.getRecipientAddressResolver().resolveRecipientAddress(recipientId).getServiceId();
final var fingerprint = computeSafetyNumberForScanning(recipientId, identityKey); return trustIdentity(serviceId, identityKey -> {
final var fingerprint = computeSafetyNumberForScanning(serviceId, identityKey);
try { try {
return fingerprint != null && fingerprint.compareTo(safetyNumber); return fingerprint != null && fingerprint.compareTo(safetyNumber);
} catch (FingerprintVersionMismatchException | FingerprintParsingException e) { } catch (FingerprintVersionMismatchException | FingerprintParsingException e) {
@ -56,35 +60,39 @@ public class IdentityHelper {
} }
public boolean trustIdentityAllKeys(RecipientId recipientId) { public boolean trustIdentityAllKeys(RecipientId recipientId) {
return trustIdentity(recipientId, identityKey -> true, TrustLevel.TRUSTED_UNVERIFIED); final var serviceId = account.getRecipientAddressResolver().resolveRecipientAddress(recipientId).getServiceId();
return trustIdentity(serviceId, identityKey -> true, TrustLevel.TRUSTED_UNVERIFIED);
} }
public String computeSafetyNumber(RecipientId recipientId, IdentityKey theirIdentityKey) { public String computeSafetyNumber(ServiceId serviceId, IdentityKey theirIdentityKey) {
var address = context.getRecipientHelper().resolveSignalServiceAddress(recipientId); final Fingerprint fingerprint = computeSafetyNumberFingerprint(serviceId, theirIdentityKey);
final Fingerprint fingerprint = computeSafetyNumberFingerprint(address, theirIdentityKey);
return fingerprint == null ? null : fingerprint.getDisplayableFingerprint().getDisplayText(); return fingerprint == null ? null : fingerprint.getDisplayableFingerprint().getDisplayText();
} }
public ScannableFingerprint computeSafetyNumberForScanning(RecipientId recipientId, IdentityKey theirIdentityKey) { public ScannableFingerprint computeSafetyNumberForScanning(ServiceId serviceId, IdentityKey theirIdentityKey) {
var address = context.getRecipientHelper().resolveSignalServiceAddress(recipientId); final Fingerprint fingerprint = computeSafetyNumberFingerprint(serviceId, theirIdentityKey);
final Fingerprint fingerprint = computeSafetyNumberFingerprint(address, theirIdentityKey);
return fingerprint == null ? null : fingerprint.getScannableFingerprint(); return fingerprint == null ? null : fingerprint.getScannableFingerprint();
} }
private Fingerprint computeSafetyNumberFingerprint( private Fingerprint computeSafetyNumberFingerprint(
final SignalServiceAddress theirAddress, final IdentityKey theirIdentityKey final ServiceId serviceId, final IdentityKey theirIdentityKey
) { ) {
final var address = account.getRecipientAddressResolver()
.resolveRecipientAddress(account.getRecipientResolver().resolveRecipient(serviceId));
return Utils.computeSafetyNumber(capabilities.isUuid(), return Utils.computeSafetyNumber(capabilities.isUuid(),
account.getSelfAddress(), account.getSelfRecipientAddress(),
account.getAciIdentityKeyPair().getPublicKey(), account.getAciIdentityKeyPair().getPublicKey(),
theirAddress, address.getServiceId().equals(serviceId)
? address
: new RecipientAddress(serviceId.uuid(), address.number().orElse(null)),
theirIdentityKey); theirIdentityKey);
} }
private boolean trustIdentity( private boolean trustIdentity(
RecipientId recipientId, Function<IdentityKey, Boolean> verifier, TrustLevel trustLevel ServiceId serviceId, Function<IdentityKey, Boolean> verifier, TrustLevel trustLevel
) { ) {
var identity = account.getIdentityKeyStore().getIdentityInfo(recipientId); var identity = account.getIdentityKeyStore().getIdentityInfo(serviceId);
if (identity == null) { if (identity == null) {
return false; return false;
} }
@ -93,9 +101,11 @@ public class IdentityHelper {
return false; return false;
} }
account.getIdentityKeyStore().setIdentityTrustLevel(recipientId, identity.getIdentityKey(), trustLevel); account.getIdentityKeyStore().setIdentityTrustLevel(serviceId, identity.getIdentityKey(), trustLevel);
try { try {
var address = context.getRecipientHelper().resolveSignalServiceAddress(recipientId); final var address = account.getRecipientAddressResolver()
.resolveRecipientAddress(account.getRecipientResolver().resolveRecipient(serviceId))
.toSignalServiceAddress();
context.getSyncHelper().sendVerifiedMessage(address, identity.getIdentityKey(), trustLevel); context.getSyncHelper().sendVerifiedMessage(address, identity.getIdentityKey(), trustLevel);
} catch (IOException e) { } catch (IOException e) {
logger.warn("Failed to send verification sync message: {}", e.getMessage()); logger.warn("Failed to send verification sync message: {}", e.getMessage());
@ -105,11 +115,13 @@ public class IdentityHelper {
} }
public void handleIdentityFailure( public void handleIdentityFailure(
final RecipientId recipientId, final SendMessageResult.IdentityFailure identityFailure final RecipientId recipientId,
final ServiceId serviceId,
final SendMessageResult.IdentityFailure identityFailure
) { ) {
final var identityKey = identityFailure.getIdentityKey(); final var identityKey = identityFailure.getIdentityKey();
if (identityKey != null) { if (identityKey != null) {
account.getIdentityKeyStore().saveIdentity(recipientId, identityKey); account.getIdentityKeyStore().saveIdentity(serviceId, identityKey);
} else { } else {
// Retrieve profile to get the current identity key from the server // Retrieve profile to get the current identity key from the server
context.getProfileHelper().refreshRecipientProfile(recipientId); context.getProfileHelper().refreshRecipientProfile(recipientId);

View file

@ -42,7 +42,6 @@ import org.signal.libsignal.metadata.ProtocolInvalidMessageException;
import org.signal.libsignal.metadata.ProtocolNoSessionException; import org.signal.libsignal.metadata.ProtocolNoSessionException;
import org.signal.libsignal.metadata.ProtocolUntrustedIdentityException; import org.signal.libsignal.metadata.ProtocolUntrustedIdentityException;
import org.signal.libsignal.metadata.SelfSendException; import org.signal.libsignal.metadata.SelfSendException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.message.DecryptionErrorMessage; import org.signal.libsignal.protocol.message.DecryptionErrorMessage;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.signal.libsignal.zkgroup.profiles.ProfileKey;
@ -57,6 +56,7 @@ import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage
import org.whispersystems.signalservice.api.messages.SignalServiceStoryMessage; import org.whispersystems.signalservice.api.messages.SignalServiceStoryMessage;
import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage; import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage;
import org.whispersystems.signalservice.api.messages.multidevice.StickerPackOperationMessage; import org.whispersystems.signalservice.api.messages.multidevice.StickerPackOperationMessage;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.util.ArrayList; import java.util.ArrayList;
@ -137,16 +137,17 @@ public final class IncomingMessageHandler {
} else { } else {
final var senderProfile = context.getProfileHelper().getRecipientProfile(sender); final var senderProfile = context.getProfileHelper().getRecipientProfile(sender);
final var selfProfile = context.getProfileHelper().getSelfProfile(); final var selfProfile = context.getProfileHelper().getSelfProfile();
final var serviceId = ServiceId.parseOrThrow(e.getSender());
if ((!sender.equals(account.getSelfRecipientId()) || e.getSenderDevice() != account.getDeviceId()) if ((!sender.equals(account.getSelfRecipientId()) || e.getSenderDevice() != account.getDeviceId())
&& senderProfile != null && senderProfile != null
&& senderProfile.getCapabilities().contains(Profile.Capability.senderKey) && senderProfile.getCapabilities().contains(Profile.Capability.senderKey)
&& selfProfile != null && selfProfile != null
&& selfProfile.getCapabilities().contains(Profile.Capability.senderKey)) { && selfProfile.getCapabilities().contains(Profile.Capability.senderKey)) {
logger.debug("Received invalid message, requesting message resend."); logger.debug("Received invalid message, requesting message resend.");
actions.add(new SendRetryMessageRequestAction(sender, e, envelope)); actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope));
} else { } else {
logger.debug("Received invalid message, queuing renew session action."); logger.debug("Received invalid message, queuing renew session action.");
actions.add(new RenewSessionAction(sender)); actions.add(new RenewSessionAction(sender, serviceId));
} }
} }
exception = e; exception = e;
@ -176,9 +177,9 @@ public final class IncomingMessageHandler {
account.getRecipientTrustedResolver().resolveRecipientTrusted(content.getSender()); account.getRecipientTrustedResolver().resolveRecipientTrusted(content.getSender());
} }
if (envelope.isReceipt()) { if (envelope.isReceipt()) {
final var senderPair = getSender(envelope, content); final var senderDeviceAddress = getSender(envelope, content);
final var sender = senderPair.first(); final var sender = senderDeviceAddress.serviceId();
final var senderDeviceId = senderPair.second(); final var senderDeviceId = senderDeviceAddress.deviceId();
account.getMessageSendLogStore().deleteEntryForRecipient(envelope.getTimestamp(), sender, senderDeviceId); account.getMessageSendLogStore().deleteEntryForRecipient(envelope.getTimestamp(), sender, senderDeviceId);
} }
@ -211,24 +212,23 @@ public final class IncomingMessageHandler {
SignalServiceEnvelope envelope, SignalServiceContent content, ReceiveConfig receiveConfig SignalServiceEnvelope envelope, SignalServiceContent content, ReceiveConfig receiveConfig
) { ) {
var actions = new ArrayList<HandleAction>(); var actions = new ArrayList<HandleAction>();
final var senderPair = getSender(envelope, content); final var senderDeviceAddress = getSender(envelope, content);
final var sender = senderPair.first(); final var sender = senderDeviceAddress.recipientId();
final var senderDeviceId = senderPair.second(); final var senderServiceId = senderDeviceAddress.serviceId();
final var senderDeviceId = senderDeviceAddress.deviceId();
final var destination = getDestination(envelope); final var destination = getDestination(envelope);
if (content.getReceiptMessage().isPresent()) { if (content.getReceiptMessage().isPresent()) {
final var message = content.getReceiptMessage().get(); final var message = content.getReceiptMessage().get();
if (message.isDeliveryReceipt()) { if (message.isDeliveryReceipt()) {
account.getMessageSendLogStore() account.getMessageSendLogStore()
.deleteEntriesForRecipient(message.getTimestamps(), sender, senderDeviceId); .deleteEntriesForRecipient(message.getTimestamps(), senderServiceId, senderDeviceId);
} }
} }
if (content.getSenderKeyDistributionMessage().isPresent()) { if (content.getSenderKeyDistributionMessage().isPresent()) {
final var message = content.getSenderKeyDistributionMessage().get(); final var message = content.getSenderKeyDistributionMessage().get();
final var protocolAddress = new SignalProtocolAddress(context.getRecipientHelper() final var protocolAddress = senderServiceId.toProtocolAddress(senderDeviceId);
.resolveSignalServiceAddress(sender)
.getIdentifier(), senderDeviceId);
logger.debug("Received a sender key distribution message for distributionId {} from {}", logger.debug("Received a sender key distribution message for distributionId {} from {}",
message.getDistributionId(), message.getDistributionId(),
protocolAddress); protocolAddress);
@ -242,7 +242,7 @@ public final class IncomingMessageHandler {
senderDeviceId, senderDeviceId,
message.getTimestamp()); message.getTimestamp());
if (message.getDeviceId() == account.getDeviceId()) { if (message.getDeviceId() == account.getDeviceId()) {
handleDecryptionErrorMessage(actions, sender, senderDeviceId, message); handleDecryptionErrorMessage(actions, sender, senderServiceId, senderDeviceId, message);
} else { } else {
logger.debug("Request is for another one of our devices"); logger.debug("Request is for another one of our devices");
} }
@ -274,7 +274,7 @@ public final class IncomingMessageHandler {
actions.addAll(handleSignalServiceDataMessage(message, actions.addAll(handleSignalServiceDataMessage(message,
false, false,
sender, senderDeviceAddress,
destination, destination,
receiveConfig.ignoreAttachments())); receiveConfig.ignoreAttachments()));
} }
@ -286,7 +286,7 @@ public final class IncomingMessageHandler {
if (content.getSyncMessage().isPresent()) { if (content.getSyncMessage().isPresent()) {
var syncMessage = content.getSyncMessage().get(); var syncMessage = content.getSyncMessage().get();
actions.addAll(handleSyncMessage(syncMessage, sender, receiveConfig.ignoreAttachments())); actions.addAll(handleSyncMessage(syncMessage, senderDeviceAddress, receiveConfig.ignoreAttachments()));
} }
return actions; return actions;
@ -295,11 +295,15 @@ public final class IncomingMessageHandler {
private void handleDecryptionErrorMessage( private void handleDecryptionErrorMessage(
final List<HandleAction> actions, final List<HandleAction> actions,
final RecipientId sender, final RecipientId sender,
final ServiceId senderServiceId,
final int senderDeviceId, final int senderDeviceId,
final DecryptionErrorMessage message final DecryptionErrorMessage message
) { ) {
final var logEntries = account.getMessageSendLogStore() final var logEntries = account.getMessageSendLogStore()
.findMessages(sender, senderDeviceId, message.getTimestamp(), message.getRatchetKey().isEmpty()); .findMessages(senderServiceId,
senderDeviceId,
message.getTimestamp(),
message.getRatchetKey().isEmpty());
for (final var logEntry : logEntries) { for (final var logEntry : logEntries) {
actions.add(new ResendMessageAction(sender, message.getTimestamp(), logEntry)); actions.add(new ResendMessageAction(sender, message.getTimestamp(), logEntry));
@ -307,13 +311,13 @@ public final class IncomingMessageHandler {
if (message.getRatchetKey().isPresent()) { if (message.getRatchetKey().isPresent()) {
if (account.getAciSessionStore() if (account.getAciSessionStore()
.isCurrentRatchetKey(sender, senderDeviceId, message.getRatchetKey().get())) { .isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) {
if (logEntries.isEmpty()) { if (logEntries.isEmpty()) {
logger.debug("Renewing the session with sender"); logger.debug("Renewing the session with sender");
actions.add(new RenewSessionAction(sender)); actions.add(new RenewSessionAction(sender, senderServiceId));
} else { } else {
logger.trace("Archiving the session with sender, a resend message has already been queued"); logger.trace("Archiving the session with sender, a resend message has already been queued");
context.getAccount().getAciSessionStore().archiveSessions(sender); context.getAccount().getAciSessionStore().archiveSessions(senderServiceId);
} }
} }
return; return;
@ -333,16 +337,16 @@ public final class IncomingMessageHandler {
sender, sender,
senderDeviceId, senderDeviceId,
group.getDistributionId()); group.getDistributionId());
account.getSenderKeyStore().deleteSharedWith(sender, senderDeviceId, group.getDistributionId()); account.getSenderKeyStore().deleteSharedWith(senderServiceId, senderDeviceId, group.getDistributionId());
} }
if (!found) { if (!found) {
logger.debug("Reset all shared sender keys with this recipient, no related message found in send log"); logger.debug("Reset all shared sender keys with this recipient, no related message found in send log");
account.getSenderKeyStore().deleteSharedWith(sender); account.getSenderKeyStore().deleteSharedWith(senderServiceId);
} }
} }
private List<HandleAction> handleSyncMessage( private List<HandleAction> handleSyncMessage(
final SignalServiceSyncMessage syncMessage, final RecipientId sender, final boolean ignoreAttachments final SignalServiceSyncMessage syncMessage, final DeviceAddress sender, final boolean ignoreAttachments
) { ) {
var actions = new ArrayList<HandleAction>(); var actions = new ArrayList<HandleAction>();
account.setMultiDevice(true); account.setMultiDevice(true);
@ -353,12 +357,16 @@ public final class IncomingMessageHandler {
actions.addAll(handleSignalServiceDataMessage(message.getDataMessage().get(), actions.addAll(handleSignalServiceDataMessage(message.getDataMessage().get(),
true, true,
sender, sender,
destination == null ? null : context.getRecipientHelper().resolveRecipient(destination), destination == null
? null
: new DeviceAddress(context.getRecipientHelper().resolveRecipient(destination),
destination.getServiceId(),
0),
ignoreAttachments)); ignoreAttachments));
} }
if (message.getStoryMessage().isPresent()) { if (message.getStoryMessage().isPresent()) {
actions.addAll(handleSignalServiceStoryMessage(message.getStoryMessage().get(), actions.addAll(handleSignalServiceStoryMessage(message.getStoryMessage().get(),
sender, sender.recipientId(),
ignoreAttachments)); ignoreAttachments));
} }
} }
@ -423,8 +431,7 @@ public final class IncomingMessageHandler {
if (syncMessage.getVerified().isPresent()) { if (syncMessage.getVerified().isPresent()) {
final var verifiedMessage = syncMessage.getVerified().get(); final var verifiedMessage = syncMessage.getVerified().get();
account.getIdentityKeyStore() account.getIdentityKeyStore()
.setIdentityTrustLevel(account.getRecipientTrustedResolver() .setIdentityTrustLevel(verifiedMessage.getDestination().getServiceId(),
.resolveRecipientTrusted(verifiedMessage.getDestination()),
verifiedMessage.getIdentityKey(), verifiedMessage.getIdentityKey(),
TrustLevel.fromVerifiedState(verifiedMessage.getVerified())); TrustLevel.fromVerifiedState(verifiedMessage.getVerified()));
} }
@ -575,8 +582,8 @@ public final class IncomingMessageHandler {
private List<HandleAction> handleSignalServiceDataMessage( private List<HandleAction> handleSignalServiceDataMessage(
SignalServiceDataMessage message, SignalServiceDataMessage message,
boolean isSync, boolean isSync,
RecipientId source, DeviceAddress source,
RecipientId destination, DeviceAddress destination,
boolean ignoreAttachments boolean ignoreAttachments
) { ) {
var actions = new ArrayList<HandleAction>(); var actions = new ArrayList<HandleAction>();
@ -616,19 +623,19 @@ public final class IncomingMessageHandler {
} }
case DELIVER: case DELIVER:
if (groupV1 == null && !isSync) { if (groupV1 == null && !isSync) {
actions.add(new SendGroupInfoRequestAction(source, groupId)); actions.add(new SendGroupInfoRequestAction(source.recipientId(), groupId));
} }
break; break;
case QUIT: { case QUIT: {
if (groupV1 != null) { if (groupV1 != null) {
groupV1.removeMember(source); groupV1.removeMember(source.recipientId());
account.getGroupStore().updateGroup(groupV1); account.getGroupStore().updateGroup(groupV1);
} }
break; break;
} }
case REQUEST_INFO: case REQUEST_INFO:
if (groupV1 != null && !isSync) { if (groupV1 != null && !isSync) {
actions.add(new SendGroupInfoAction(source, groupV1.getGroupId())); actions.add(new SendGroupInfoAction(source.recipientId(), groupV1.getGroupId()));
} }
break; break;
} }
@ -643,7 +650,7 @@ public final class IncomingMessageHandler {
final var conversationPartnerAddress = isSync ? destination : source; final var conversationPartnerAddress = isSync ? destination : source;
if (conversationPartnerAddress != null && message.isEndSession()) { if (conversationPartnerAddress != null && message.isEndSession()) {
account.getAciSessionStore().deleteAllSessions(conversationPartnerAddress); account.getAciSessionStore().deleteAllSessions(conversationPartnerAddress.serviceId());
} }
if (message.isExpirationUpdate() || message.getBody().isPresent()) { if (message.isExpirationUpdate() || message.getBody().isPresent()) {
if (message.getGroupContext().isPresent()) { if (message.getGroupContext().isPresent()) {
@ -662,7 +669,7 @@ public final class IncomingMessageHandler {
} }
} else if (conversationPartnerAddress != null) { } else if (conversationPartnerAddress != null) {
context.getContactHelper() context.getContactHelper()
.setExpirationTimer(conversationPartnerAddress, message.getExpiresInSeconds()); .setExpirationTimer(conversationPartnerAddress.recipientId(), message.getExpiresInSeconds());
} }
} }
if (!ignoreAttachments) { if (!ignoreAttachments) {
@ -698,7 +705,7 @@ public final class IncomingMessageHandler {
} }
} }
if (message.getProfileKey().isPresent()) { if (message.getProfileKey().isPresent()) {
handleIncomingProfileKey(message.getProfileKey().get(), source); handleIncomingProfileKey(message.getProfileKey().get(), source.recipientId());
} }
if (message.getSticker().isPresent()) { if (message.getSticker().isPresent()) {
final var messageSticker = message.getSticker().get(); final var messageSticker = message.getSticker().get();
@ -769,24 +776,29 @@ public final class IncomingMessageHandler {
this.account.getProfileStore().storeProfileKey(source, profileKey); this.account.getProfileStore().storeProfileKey(source, profileKey);
} }
private Pair<RecipientId, Integer> getSender(SignalServiceEnvelope envelope, SignalServiceContent content) { private DeviceAddress getSender(SignalServiceEnvelope envelope, SignalServiceContent content) {
if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) { if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) {
return new Pair<>(context.getRecipientHelper().resolveRecipient(envelope.getSourceAddress()), return new DeviceAddress(context.getRecipientHelper().resolveRecipient(envelope.getSourceAddress()),
envelope.getSourceAddress().getServiceId(),
envelope.getSourceDevice()); envelope.getSourceDevice());
} else { } else {
return new Pair<>(context.getRecipientHelper().resolveRecipient(content.getSender()), return new DeviceAddress(context.getRecipientHelper().resolveRecipient(content.getSender()),
content.getSender().getServiceId(),
content.getSenderDevice()); content.getSenderDevice());
} }
} }
private RecipientId getDestination(SignalServiceEnvelope envelope) { private DeviceAddress getDestination(SignalServiceEnvelope envelope) {
if (!envelope.hasDestinationUuid()) { if (!envelope.hasDestinationUuid()) {
return account.getSelfRecipientId(); return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId());
} }
final var addressOptional = SignalServiceAddress.fromRaw(envelope.getDestinationUuid(), null); final var addressOptional = SignalServiceAddress.fromRaw(envelope.getDestinationUuid(), null);
if (addressOptional.isEmpty()) { if (addressOptional.isEmpty()) {
return account.getSelfRecipientId(); return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId());
} }
return context.getRecipientHelper().resolveRecipient(addressOptional.get()); final var address = addressOptional.get();
return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address), address.getServiceId(), 0);
} }
private record DeviceAddress(RecipientId recipientId, ServiceId serviceId, int deviceId) {}
} }

View file

@ -346,7 +346,7 @@ public final class ProfileHelper {
try { try {
logger.trace("Storing identity"); logger.trace("Storing identity");
final var identityKey = new IdentityKey(Base64.getDecoder().decode(encryptedProfile.getIdentityKey())); final var identityKey = new IdentityKey(Base64.getDecoder().decode(encryptedProfile.getIdentityKey()));
account.getIdentityKeyStore().saveIdentity(recipientId, identityKey); account.getIdentityKeyStore().saveIdentity(p.getProfile().getServiceId(), identityKey);
} catch (InvalidKeyException ignored) { } catch (InvalidKeyException ignored) {
logger.warn("Got invalid identity key in profile for {}", logger.warn("Got invalid identity key in profile for {}",
context.getRecipientHelper().resolveSignalServiceAddress(recipientId).getIdentifier()); context.getRecipientHelper().resolveSignalServiceAddress(recipientId).getIdentifier());

View file

@ -486,7 +486,10 @@ public class SendHelper {
continue; continue;
} }
final var identity = account.getIdentityKeyStore().getIdentityInfo(recipientId); final var serviceId = account.getRecipientAddressResolver()
.resolveRecipientAddress(recipientId)
.getServiceId();
final var identity = account.getIdentityKeyStore().getIdentityInfo(serviceId);
if (identity == null || !identity.getTrustLevel().isTrusted()) { if (identity == null || !identity.getTrustLevel().isTrusted()) {
continue; continue;
} }
@ -531,7 +534,7 @@ public class SendHelper {
final var recipientIdList = new ArrayList<>(recipientIds); final var recipientIdList = new ArrayList<>(recipientIds);
long keyCreateTime = account.getSenderKeyStore() long keyCreateTime = account.getSenderKeyStore()
.getCreateTimeForOurKey(account.getSelfRecipientId(), account.getDeviceId(), distributionId); .getCreateTimeForOurKey(account.getAci(), account.getDeviceId(), distributionId);
long keyAge = System.currentTimeMillis() - keyCreateTime; long keyAge = System.currentTimeMillis() - keyCreateTime;
if (keyCreateTime != -1 && keyAge > TimeUnit.DAYS.toMillis(14)) { if (keyCreateTime != -1 && keyAge > TimeUnit.DAYS.toMillis(14)) {
@ -540,7 +543,7 @@ public class SendHelper {
keyCreateTime, keyCreateTime,
keyAge, keyAge,
TimeUnit.MILLISECONDS.toDays(keyAge)); TimeUnit.MILLISECONDS.toDays(keyAge));
account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId); account.getSenderKeyStore().deleteOurKey(account.getAci(), distributionId);
} }
List<SignalServiceAddress> addresses = recipientIdList.stream() List<SignalServiceAddress> addresses = recipientIdList.stream()
@ -573,11 +576,11 @@ public class SendHelper {
return null; return null;
} catch (NoSessionException e) { } catch (NoSessionException e) {
logger.warn("No session. Falling back to legacy sends.", e); logger.warn("No session. Falling back to legacy sends.", e);
account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId); account.getSenderKeyStore().deleteOurKey(account.getAci(), distributionId);
return null; return null;
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
logger.warn("Invalid key. Falling back to legacy sends.", e); logger.warn("Invalid key. Falling back to legacy sends.", e);
account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId); account.getSenderKeyStore().deleteOurKey(account.getAci(), distributionId);
return null; return null;
} catch (InvalidRegistrationIdException e) { } catch (InvalidRegistrationIdException e) {
logger.warn("Invalid registrationId. Falling back to legacy sends.", e); logger.warn("Invalid registrationId. Falling back to legacy sends.", e);
@ -685,7 +688,8 @@ public class SendHelper {
} }
if (r.getIdentityFailure() != null) { if (r.getIdentityFailure() != null) {
final var recipientId = context.getRecipientHelper().resolveRecipient(r.getAddress()); final var recipientId = context.getRecipientHelper().resolveRecipient(r.getAddress());
context.getIdentityHelper().handleIdentityFailure(recipientId, r.getIdentityFailure()); context.getIdentityHelper()
.handleIdentityFailure(recipientId, r.getAddress().getServiceId(), r.getIdentityFailure());
} }
} }

View file

@ -144,11 +144,12 @@ public class StorageHelper {
try { try {
logger.trace("Storing identity key {}", recipientId); logger.trace("Storing identity key {}", recipientId);
final var identityKey = new IdentityKey(contactRecord.getIdentityKey().get()); final var identityKey = new IdentityKey(contactRecord.getIdentityKey().get());
account.getIdentityKeyStore().saveIdentity(recipientId, identityKey); account.getIdentityKeyStore().saveIdentity(address.getServiceId(), identityKey);
final var trustLevel = TrustLevel.fromIdentityState(contactRecord.getIdentityState()); final var trustLevel = TrustLevel.fromIdentityState(contactRecord.getIdentityState());
if (trustLevel != null) { if (trustLevel != null) {
account.getIdentityKeyStore().setIdentityTrustLevel(recipientId, identityKey, trustLevel); account.getIdentityKeyStore()
.setIdentityTrustLevel(address.getServiceId(), identityKey, trustLevel);
} }
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
logger.warn("Received invalid contact identity key from storage"); logger.warn("Received invalid contact identity key from storage");

View file

@ -132,7 +132,7 @@ public class SyncHelper {
final var contact = contactPair.second(); final var contact = contactPair.second();
final var address = context.getRecipientHelper().resolveSignalServiceAddress(recipientId); final var address = context.getRecipientHelper().resolveSignalServiceAddress(recipientId);
var currentIdentity = account.getIdentityKeyStore().getIdentityInfo(recipientId); var currentIdentity = account.getIdentityKeyStore().getIdentityInfo(address.getServiceId());
VerifiedMessage verifiedMessage = null; VerifiedMessage verifiedMessage = null;
if (currentIdentity != null) { if (currentIdentity != null) {
verifiedMessage = new VerifiedMessage(address, verifiedMessage = new VerifiedMessage(address,
@ -319,8 +319,7 @@ public class SyncHelper {
if (c.getVerified().isPresent()) { if (c.getVerified().isPresent()) {
final var verifiedMessage = c.getVerified().get(); final var verifiedMessage = c.getVerified().get();
account.getIdentityKeyStore() account.getIdentityKeyStore()
.setIdentityTrustLevel(account.getRecipientTrustedResolver() .setIdentityTrustLevel(verifiedMessage.getDestination().getServiceId(),
.resolveRecipientTrusted(verifiedMessage.getDestination()),
verifiedMessage.getIdentityKey(), verifiedMessage.getIdentityKey(),
TrustLevel.fromVerifiedState(verifiedMessage.getVerified())); TrustLevel.fromVerifiedState(verifiedMessage.getVerified()));
} }

View file

@ -22,7 +22,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database { public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 9; private static final long DATABASE_VERSION = 10;
private AccountDatabase(final HikariDataSource dataSource) { private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource); super(logger, DATABASE_VERSION, dataSource);
@ -212,5 +212,81 @@ public class AccountDatabase extends Database {
"""); """);
} }
} }
if (oldVersion < 10) {
logger.debug("Updating database: Key tables on serviceId instead of recipientId");
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE identity2 (
_id INTEGER PRIMARY KEY,
uuid BLOB UNIQUE NOT NULL,
identity_key BLOB NOT NULL,
added_timestamp INTEGER NOT NULL,
trust_level INTEGER NOT NULL
) STRICT;
INSERT INTO identity2 (_id, uuid, identity_key, added_timestamp, trust_level)
SELECT i._id, r.uuid, i.identity_key, i.added_timestamp, i.trust_level
FROM identity i LEFT JOIN recipient r ON i.recipient_id = r._id
WHERE uuid IS NOT NULL;
DROP TABLE identity;
ALTER TABLE identity2 RENAME TO identity;
DROP INDEX msl_recipient_index;
ALTER TABLE message_send_log ADD COLUMN uuid BLOB;
UPDATE message_send_log
SET uuid = r.uuid
FROM message_send_log i, (SELECT _id, uuid FROM recipient) AS r
WHERE i.recipient_id = r._id;
DELETE FROM message_send_log WHERE uuid IS NULL;
ALTER TABLE message_send_log DROP COLUMN recipient_id;
CREATE INDEX msl_recipient_index ON message_send_log (uuid, device_id, content_id);
CREATE TABLE sender_key2 (
_id INTEGER PRIMARY KEY,
uuid BLOB NOT NULL,
device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL,
record BLOB NOT NULL,
created_timestamp INTEGER NOT NULL,
UNIQUE(uuid, device_id, distribution_id)
) STRICT;
INSERT INTO sender_key2 (_id, uuid, device_id, distribution_id, record, created_timestamp)
SELECT s._id, r.uuid, s.device_id, s.distribution_id, s.record, s.created_timestamp
FROM sender_key s LEFT JOIN recipient r ON s.recipient_id = r._id
WHERE uuid IS NOT NULL;
DROP TABLE sender_key;
ALTER TABLE sender_key2 RENAME TO sender_key;
CREATE TABLE sender_key_shared2 (
_id INTEGER PRIMARY KEY,
uuid BLOB NOT NULL,
device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL,
timestamp INTEGER NOT NULL,
UNIQUE(uuid, device_id, distribution_id)
) STRICT;
INSERT INTO sender_key_shared2 (_id, uuid, device_id, distribution_id, timestamp)
SELECT s._id, r.uuid, s.device_id, s.distribution_id, s.timestamp
FROM sender_key_shared s LEFT JOIN recipient r ON s.recipient_id = r._id
WHERE uuid IS NOT NULL;
DROP TABLE sender_key_shared;
ALTER TABLE sender_key_shared2 RENAME TO sender_key_shared;
CREATE TABLE session2 (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
uuid BLOB NOT NULL,
device_id INTEGER NOT NULL,
record BLOB NOT NULL,
UNIQUE(account_id_type, uuid, device_id)
) STRICT;
INSERT INTO session2 (_id, account_id_type, uuid, device_id, record)
SELECT s._id, s.account_id_type, r.uuid, s.device_id, s.record
FROM session s LEFT JOIN recipient r ON s.recipient_id = r._id
WHERE uuid IS NOT NULL;
DROP TABLE session;
ALTER TABLE session2 RENAME TO session;
""");
}
}
} }
} }

View file

@ -383,6 +383,14 @@ public class SignalAccount implements Closeable {
this.storageManifestVersion = -1; this.storageManifestVersion = -1;
this.setStorageManifest(null); this.setStorageManifest(null);
this.storageKey = null; this.storageKey = null;
final var aciPublicKey = getAciIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
if (getPniIdentityKeyPair() != null) {
final var pniPublicKey = getPniIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
}
} }
private void migrateLegacyConfigs() { private void migrateLegacyConfigs() {
@ -400,21 +408,21 @@ public class SignalAccount implements Closeable {
} }
private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
getAciSessionStore().mergeRecipients(recipientId, toBeMergedRecipientId);
getPniSessionStore().mergeRecipients(recipientId, toBeMergedRecipientId);
getIdentityKeyStore().mergeRecipients(recipientId, toBeMergedRecipientId);
getMessageCache().mergeRecipients(recipientId, toBeMergedRecipientId); getMessageCache().mergeRecipients(recipientId, toBeMergedRecipientId);
getGroupStore().mergeRecipients(recipientId, toBeMergedRecipientId); getGroupStore().mergeRecipients(recipientId, toBeMergedRecipientId);
getSenderKeyStore().mergeRecipients(recipientId, toBeMergedRecipientId);
} }
public void removeRecipient(final RecipientId recipientId) { public void removeRecipient(final RecipientId recipientId) {
getAciSessionStore().deleteAllSessions(recipientId);
getPniSessionStore().deleteAllSessions(recipientId);
getIdentityKeyStore().deleteIdentity(recipientId);
getMessageCache().deleteMessages(recipientId);
getSenderKeyStore().deleteAll(recipientId);
getRecipientStore().deleteRecipientData(recipientId); getRecipientStore().deleteRecipientData(recipientId);
getMessageCache().deleteMessages(recipientId);
final var recipientAddress = getRecipientStore().resolveRecipientAddress(recipientId);
if (recipientAddress.uuid().isPresent()) {
final var serviceId = ServiceId.from(recipientAddress.uuid().get());
getAciSessionStore().deleteAllSessions(serviceId);
getPniSessionStore().deleteAllSessions(serviceId);
getIdentityKeyStore().deleteIdentity(serviceId);
getSenderKeyStore().deleteAll(serviceId);
}
} }
public static File getFileName(File dataPath, String account) { public static File getFileName(File dataPath, String account) {
@ -646,12 +654,18 @@ public class SignalAccount implements Closeable {
} }
final var legacySessionsPath = getSessionsPath(dataPath, accountPath); final var legacySessionsPath = getSessionsPath(dataPath, accountPath);
if (legacySessionsPath.exists()) { if (legacySessionsPath.exists()) {
LegacySessionStore.migrate(legacySessionsPath, getRecipientResolver(), getAciSessionStore()); LegacySessionStore.migrate(legacySessionsPath,
getRecipientResolver(),
getRecipientAddressResolver(),
getAciSessionStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath); final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath);
if (legacyIdentitiesPath.exists()) { if (legacyIdentitiesPath.exists()) {
LegacyIdentityKeyStore.migrate(legacyIdentitiesPath, getRecipientResolver(), getIdentityKeyStore()); LegacyIdentityKeyStore.migrate(legacyIdentitiesPath,
getRecipientResolver(),
getRecipientAddressResolver(),
getIdentityKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore") final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore")
@ -672,12 +686,18 @@ public class SignalAccount implements Closeable {
final var legacySenderKeysPath = getSenderKeysPath(dataPath, accountPath); final var legacySenderKeysPath = getSenderKeysPath(dataPath, accountPath);
if (legacySenderKeysPath.exists()) { if (legacySenderKeysPath.exists()) {
LegacySenderKeyRecordStore.migrate(legacySenderKeysPath, getRecipientResolver(), getSenderKeyStore()); LegacySenderKeyRecordStore.migrate(legacySenderKeysPath,
getRecipientResolver(),
getRecipientAddressResolver(),
getSenderKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
final var legacySenderKeysSharedPath = getSharedSenderKeysFile(dataPath, accountPath); final var legacySenderKeysSharedPath = getSharedSenderKeysFile(dataPath, accountPath);
if (legacySenderKeysSharedPath.exists()) { if (legacySenderKeysSharedPath.exists()) {
LegacySenderKeySharedStore.migrate(legacySenderKeysSharedPath, getRecipientResolver(), getSenderKeyStore()); LegacySenderKeySharedStore.migrate(legacySenderKeysSharedPath,
getRecipientResolver(),
getRecipientAddressResolver(),
getSenderKeyStore());
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
if (rootNode.hasNonNull("groupStore")) { if (rootNode.hasNonNull("groupStore")) {
@ -770,9 +790,12 @@ public class SignalAccount implements Closeable {
if (legacySignalProtocolStore != null && legacySignalProtocolStore.getLegacyIdentityKeyStore() != null) { if (legacySignalProtocolStore != null && legacySignalProtocolStore.getLegacyIdentityKeyStore() != null) {
logger.debug("Migrating legacy identity session store."); logger.debug("Migrating legacy identity session store.");
for (var identity : legacySignalProtocolStore.getLegacyIdentityKeyStore().getIdentities()) { for (var identity : legacySignalProtocolStore.getLegacyIdentityKeyStore().getIdentities()) {
RecipientId recipientId = getRecipientStore().resolveRecipientTrusted(identity.getAddress()); if (identity.getAddress().uuid().isEmpty()) {
getIdentityKeyStore().saveIdentity(recipientId, identity.getIdentityKey()); continue;
getIdentityKeyStore().setIdentityTrustLevel(recipientId, }
final var serviceId = identity.getAddress().getServiceId();
getIdentityKeyStore().saveIdentity(serviceId, identity.getIdentityKey());
getIdentityKeyStore().setIdentityTrustLevel(serviceId,
identity.getIdentityKey(), identity.getIdentityKey(),
identity.getTrustLevel()); identity.getTrustLevel());
} }
@ -1107,25 +1130,17 @@ public class SignalAccount implements Closeable {
public SessionStore getAciSessionStore() { public SessionStore getAciSessionStore() {
return getOrCreate(() -> aciSessionStore, return getOrCreate(() -> aciSessionStore,
() -> aciSessionStore = new SessionStore(getAccountDatabase(), () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
ServiceIdType.ACI,
getRecipientResolver(),
getRecipientIdCreator()));
} }
public SessionStore getPniSessionStore() { public SessionStore getPniSessionStore() {
return getOrCreate(() -> pniSessionStore, return getOrCreate(() -> pniSessionStore,
() -> pniSessionStore = new SessionStore(getAccountDatabase(), () -> pniSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.PNI));
ServiceIdType.PNI,
getRecipientResolver(),
getRecipientIdCreator()));
} }
public IdentityKeyStore getIdentityKeyStore() { public IdentityKeyStore getIdentityKeyStore() {
return getOrCreate(() -> identityKeyStore, return getOrCreate(() -> identityKeyStore,
() -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), () -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), trustNewIdentity));
getRecipientIdCreator(),
trustNewIdentity));
} }
public SignalIdentityKeyStore getAciIdentityKeyStore() { public SignalIdentityKeyStore getAciIdentityKeyStore() {
@ -1207,11 +1222,7 @@ public class SignalAccount implements Closeable {
} }
public SenderKeyStore getSenderKeyStore() { public SenderKeyStore getSenderKeyStore() {
return getOrCreate(() -> senderKeyStore, return getOrCreate(() -> senderKeyStore, () -> senderKeyStore = new SenderKeyStore(getAccountDatabase()));
() -> senderKeyStore = new SenderKeyStore(getAccountDatabase(),
getRecipientAddressResolver(),
getRecipientResolver(),
getRecipientIdCreator()));
} }
public ConfigurationStore getConfigurationStore() { public ConfigurationStore getConfigurationStore() {
@ -1235,7 +1246,7 @@ public class SignalAccount implements Closeable {
public MessageSendLogStore getMessageSendLogStore() { public MessageSendLogStore getMessageSendLogStore() {
return getOrCreate(() -> messageSendLogStore, return getOrCreate(() -> messageSendLogStore,
() -> messageSendLogStore = new MessageSendLogStore(getRecipientResolver(), getAccountDatabase())); () -> messageSendLogStore = new MessageSendLogStore(getAccountDatabase()));
} }
public CredentialsProvider getCredentialsProvider() { public CredentialsProvider getCredentialsProvider() {
@ -1350,6 +1361,9 @@ public class SignalAccount implements Closeable {
public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) { public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) {
pniIdentityKeyPair = identityKeyPair; pniIdentityKeyPair = identityKeyPair;
final var pniPublicKey = getPniIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
save(); save();
} }
@ -1553,10 +1567,13 @@ public class SignalAccount implements Closeable {
getAciSessionStore().archiveAllSessions(); getAciSessionStore().archiveAllSessions();
getPniSessionStore().archiveAllSessions(); getPniSessionStore().archiveAllSessions();
getSenderKeyStore().deleteAll(); getSenderKeyStore().deleteAll();
final var recipientId = getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress()); getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress());
final var publicKey = getAciIdentityKeyPair().getPublicKey(); final var aciPublicKey = getAciIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(recipientId, publicKey); getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(recipientId, publicKey, TrustLevel.TRUSTED_VERIFIED); getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
final var pniPublicKey = getPniIdentityKeyPair().getPublicKey();
getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
} }
public void deleteAccountData() throws IOException { public void deleteAccountData() throws IOException {

View file

@ -1,27 +1,27 @@
package org.asamk.signal.manager.storage.identities; package org.asamk.signal.manager.storage.identities;
import org.asamk.signal.manager.api.TrustLevel; import org.asamk.signal.manager.api.TrustLevel;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.whispersystems.signalservice.api.push.ServiceId;
public class IdentityInfo { public class IdentityInfo {
private final RecipientId recipientId; private final ServiceId serviceId;
private final IdentityKey identityKey; private final IdentityKey identityKey;
private final TrustLevel trustLevel; private final TrustLevel trustLevel;
private final long addedTimestamp; private final long addedTimestamp;
IdentityInfo( IdentityInfo(
final RecipientId recipientId, IdentityKey identityKey, TrustLevel trustLevel, long addedTimestamp final ServiceId serviceId, IdentityKey identityKey, TrustLevel trustLevel, long addedTimestamp
) { ) {
this.recipientId = recipientId; this.serviceId = serviceId;
this.identityKey = identityKey; this.identityKey = identityKey;
this.trustLevel = trustLevel; this.trustLevel = trustLevel;
this.addedTimestamp = addedTimestamp; this.addedTimestamp = addedTimestamp;
} }
public RecipientId getRecipientId() { public ServiceId getServiceId() {
return recipientId; return serviceId;
} }
public IdentityKey getIdentityKey() { public IdentityKey getIdentityKey() {

View file

@ -3,13 +3,12 @@ package org.asamk.signal.manager.storage.identities;
import org.asamk.signal.manager.api.TrustLevel; import org.asamk.signal.manager.api.TrustLevel;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.state.IdentityKeyStore.Direction; import org.signal.libsignal.protocol.state.IdentityKeyStore.Direction;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.sql.Connection; import java.sql.Connection;
import java.sql.ResultSet; import java.sql.ResultSet;
@ -25,9 +24,8 @@ public class IdentityKeyStore {
private final static Logger logger = LoggerFactory.getLogger(IdentityKeyStore.class); private final static Logger logger = LoggerFactory.getLogger(IdentityKeyStore.class);
private static final String TABLE_IDENTITY = "identity"; private static final String TABLE_IDENTITY = "identity";
private final Database database; private final Database database;
private final RecipientIdCreator recipientIdCreator;
private final TrustNewIdentity trustNewIdentity; private final TrustNewIdentity trustNewIdentity;
private final PublishSubject<RecipientId> identityChanges = PublishSubject.create(); private final PublishSubject<ServiceId> identityChanges = PublishSubject.create();
private boolean isRetryingDecryption = false; private boolean isRetryingDecryption = false;
@ -37,42 +35,37 @@ public class IdentityKeyStore {
statement.executeUpdate(""" statement.executeUpdate("""
CREATE TABLE identity ( CREATE TABLE identity (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
recipient_id INTEGER UNIQUE NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, uuid BLOB UNIQUE NOT NULL,
identity_key BLOB NOT NULL, identity_key BLOB NOT NULL,
added_timestamp INTEGER NOT NULL, added_timestamp INTEGER NOT NULL,
trust_level INTEGER NOT NULL trust_level INTEGER NOT NULL
); ) STRICT;
"""); """);
} }
} }
public IdentityKeyStore( public IdentityKeyStore(final Database database, final TrustNewIdentity trustNewIdentity) {
final Database database,
final RecipientIdCreator recipientIdCreator,
final TrustNewIdentity trustNewIdentity
) {
this.database = database; this.database = database;
this.recipientIdCreator = recipientIdCreator;
this.trustNewIdentity = trustNewIdentity; this.trustNewIdentity = trustNewIdentity;
} }
public Observable<RecipientId> getIdentityChanges() { public Observable<ServiceId> getIdentityChanges() {
return identityChanges; return identityChanges;
} }
public boolean saveIdentity(final RecipientId recipientId, final IdentityKey identityKey) { public boolean saveIdentity(final ServiceId serviceId, final IdentityKey identityKey) {
if (isRetryingDecryption) { if (isRetryingDecryption) {
return false; return false;
} }
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var identityInfo = loadIdentity(connection, recipientId); final var identityInfo = loadIdentity(connection, serviceId);
if (identityInfo != null && identityInfo.getIdentityKey().equals(identityKey)) { if (identityInfo != null && identityInfo.getIdentityKey().equals(identityKey)) {
// Identity already exists, not updating the trust level // Identity already exists, not updating the trust level
logger.trace("Not storing new identity for recipient {}, identity already stored", recipientId); logger.trace("Not storing new identity for recipient {}, identity already stored", serviceId);
return false; return false;
} }
saveNewIdentity(connection, recipientId, identityKey, identityInfo == null); saveNewIdentity(connection, serviceId, identityKey, identityInfo == null);
return true; return true;
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update identity store", e); throw new RuntimeException("Failed update identity store", e);
@ -83,24 +76,24 @@ public class IdentityKeyStore {
isRetryingDecryption = retryingDecryption; isRetryingDecryption = retryingDecryption;
} }
public boolean setIdentityTrustLevel(RecipientId recipientId, IdentityKey identityKey, TrustLevel trustLevel) { public boolean setIdentityTrustLevel(ServiceId serviceId, IdentityKey identityKey, TrustLevel trustLevel) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var identityInfo = loadIdentity(connection, recipientId); final var identityInfo = loadIdentity(connection, serviceId);
if (identityInfo == null) { if (identityInfo == null) {
logger.debug("Not updating trust level for recipient {}, identity not found", recipientId); logger.debug("Not updating trust level for recipient {}, identity not found", serviceId);
return false; return false;
} }
if (!identityInfo.getIdentityKey().equals(identityKey)) { if (!identityInfo.getIdentityKey().equals(identityKey)) {
logger.debug("Not updating trust level for recipient {}, different identity found", recipientId); logger.debug("Not updating trust level for recipient {}, different identity found", serviceId);
return false; return false;
} }
if (identityInfo.getTrustLevel() == trustLevel) { if (identityInfo.getTrustLevel() == trustLevel) {
logger.trace("Not updating trust level for recipient {}, trust level already matches", recipientId); logger.trace("Not updating trust level for recipient {}, trust level already matches", serviceId);
return false; return false;
} }
logger.debug("Updating trust level for recipient {} with trust {}", recipientId, trustLevel); logger.debug("Updating trust level for recipient {} with trust {}", serviceId, trustLevel);
final var newIdentityInfo = new IdentityInfo(recipientId, final var newIdentityInfo = new IdentityInfo(serviceId,
identityKey, identityKey,
trustLevel, trustLevel,
identityInfo.getDateAddedTimestamp()); identityInfo.getDateAddedTimestamp());
@ -111,41 +104,41 @@ public class IdentityKeyStore {
} }
} }
public boolean isTrustedIdentity(RecipientId recipientId, IdentityKey identityKey, Direction direction) { public boolean isTrustedIdentity(ServiceId serviceId, IdentityKey identityKey, Direction direction) {
if (trustNewIdentity == TrustNewIdentity.ALWAYS) { if (trustNewIdentity == TrustNewIdentity.ALWAYS) {
return true; return true;
} }
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
// TODO implement possibility for different handling of incoming/outgoing trust decisions // TODO implement possibility for different handling of incoming/outgoing trust decisions
var identityInfo = loadIdentity(connection, recipientId); var identityInfo = loadIdentity(connection, serviceId);
if (identityInfo == null) { if (identityInfo == null) {
logger.debug("Initial identity found for {}, saving.", recipientId); logger.debug("Initial identity found for {}, saving.", serviceId);
saveNewIdentity(connection, recipientId, identityKey, true); saveNewIdentity(connection, serviceId, identityKey, true);
identityInfo = loadIdentity(connection, recipientId); identityInfo = loadIdentity(connection, serviceId);
} else if (!identityInfo.getIdentityKey().equals(identityKey)) { } else if (!identityInfo.getIdentityKey().equals(identityKey)) {
// Identity found, but different // Identity found, but different
if (direction == Direction.SENDING) { if (direction == Direction.SENDING) {
logger.debug("Changed identity found for {}, saving.", recipientId); logger.debug("Changed identity found for {}, saving.", serviceId);
saveNewIdentity(connection, recipientId, identityKey, false); saveNewIdentity(connection, serviceId, identityKey, false);
identityInfo = loadIdentity(connection, recipientId); identityInfo = loadIdentity(connection, serviceId);
} else { } else {
logger.trace("Trusting identity for {} for {}: {}", recipientId, direction, false); logger.trace("Trusting identity for {} for {}: {}", serviceId, direction, false);
return false; return false;
} }
} }
final var isTrusted = identityInfo != null && identityInfo.isTrusted(); final var isTrusted = identityInfo != null && identityInfo.isTrusted();
logger.trace("Trusting identity for {} for {}: {}", recipientId, direction, isTrusted); logger.trace("Trusting identity for {} for {}: {}", serviceId, direction, isTrusted);
return isTrusted; return isTrusted;
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed read from identity store", e); throw new RuntimeException("Failed read from identity store", e);
} }
} }
public IdentityInfo getIdentityInfo(RecipientId recipientId) { public IdentityInfo getIdentityInfo(ServiceId serviceId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
return loadIdentity(connection, recipientId); return loadIdentity(connection, serviceId);
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed read from identity store", e); throw new RuntimeException("Failed read from identity store", e);
} }
@ -155,7 +148,7 @@ public class IdentityKeyStore {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var sql = ( final var sql = (
""" """
SELECT i.recipient_id, i.identity_key, i.added_timestamp, i.trust_level SELECT i.uuid, i.identity_key, i.added_timestamp, i.trust_level
FROM %s AS i FROM %s AS i
""" """
).formatted(TABLE_IDENTITY); ).formatted(TABLE_IDENTITY);
@ -167,32 +160,9 @@ public class IdentityKeyStore {
} }
} }
public void mergeRecipients(final RecipientId recipientId, final RecipientId toBeMergedRecipientId) { public void deleteIdentity(final ServiceId serviceId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
connection.setAutoCommit(false); deleteIdentity(connection, serviceId);
final var sql = (
"""
UPDATE OR IGNORE %s
SET recipient_id = ?
WHERE recipient_id = ?
"""
).formatted(TABLE_IDENTITY);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.setLong(2, toBeMergedRecipientId.id());
statement.executeUpdate();
}
deleteIdentity(connection, toBeMergedRecipientId);
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update identity store", e);
}
}
public void deleteIdentity(final RecipientId recipientId) {
try (final var connection = database.getConnection()) {
deleteIdentity(connection, recipientId);
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update identity store", e); throw new RuntimeException("Failed update identity store", e);
} }
@ -214,49 +184,49 @@ public class IdentityKeyStore {
} }
private IdentityInfo loadIdentity( private IdentityInfo loadIdentity(
final Connection connection, final RecipientId recipientId final Connection connection, final ServiceId serviceId
) throws SQLException { ) throws SQLException {
final var sql = ( final var sql = (
""" """
SELECT i.recipient_id, i.identity_key, i.added_timestamp, i.trust_level SELECT i.uuid, i.identity_key, i.added_timestamp, i.trust_level
FROM %s AS i FROM %s AS i
WHERE i.recipient_id = ? WHERE i.uuid = ?
""" """
).formatted(TABLE_IDENTITY); ).formatted(TABLE_IDENTITY);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
return Utils.executeQueryForOptional(statement, this::getIdentityInfoFromResultSet).orElse(null); return Utils.executeQueryForOptional(statement, this::getIdentityInfoFromResultSet).orElse(null);
} }
} }
private void saveNewIdentity( private void saveNewIdentity(
final Connection connection, final Connection connection,
final RecipientId recipientId, final ServiceId serviceId,
final IdentityKey identityKey, final IdentityKey identityKey,
final boolean firstIdentity final boolean firstIdentity
) throws SQLException { ) throws SQLException {
final var trustLevel = trustNewIdentity == TrustNewIdentity.ALWAYS || ( final var trustLevel = trustNewIdentity == TrustNewIdentity.ALWAYS || (
trustNewIdentity == TrustNewIdentity.ON_FIRST_USE && firstIdentity trustNewIdentity == TrustNewIdentity.ON_FIRST_USE && firstIdentity
) ? TrustLevel.TRUSTED_UNVERIFIED : TrustLevel.UNTRUSTED; ) ? TrustLevel.TRUSTED_UNVERIFIED : TrustLevel.UNTRUSTED;
logger.debug("Storing new identity for recipient {} with trust {}", recipientId, trustLevel); logger.debug("Storing new identity for recipient {} with trust {}", serviceId, trustLevel);
final var newIdentityInfo = new IdentityInfo(recipientId, identityKey, trustLevel, System.currentTimeMillis()); final var newIdentityInfo = new IdentityInfo(serviceId, identityKey, trustLevel, System.currentTimeMillis());
storeIdentity(connection, newIdentityInfo); storeIdentity(connection, newIdentityInfo);
identityChanges.onNext(recipientId); identityChanges.onNext(serviceId);
} }
private void storeIdentity(final Connection connection, final IdentityInfo identityInfo) throws SQLException { private void storeIdentity(final Connection connection, final IdentityInfo identityInfo) throws SQLException {
logger.trace("Storing identity info for {}, trust: {}, added: {}", logger.trace("Storing identity info for {}, trust: {}, added: {}",
identityInfo.getRecipientId(), identityInfo.getServiceId(),
identityInfo.getTrustLevel(), identityInfo.getTrustLevel(),
identityInfo.getDateAddedTimestamp()); identityInfo.getDateAddedTimestamp());
final var sql = ( final var sql = (
""" """
INSERT OR REPLACE INTO %s (recipient_id, identity_key, added_timestamp, trust_level) INSERT OR REPLACE INTO %s (uuid, identity_key, added_timestamp, trust_level)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
""" """
).formatted(TABLE_IDENTITY); ).formatted(TABLE_IDENTITY);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, identityInfo.getRecipientId().id()); statement.setBytes(1, identityInfo.getServiceId().toByteArray());
statement.setBytes(2, identityInfo.getIdentityKey().serialize()); statement.setBytes(2, identityInfo.getIdentityKey().serialize());
statement.setLong(3, identityInfo.getDateAddedTimestamp()); statement.setLong(3, identityInfo.getDateAddedTimestamp());
statement.setInt(4, identityInfo.getTrustLevel().ordinal()); statement.setInt(4, identityInfo.getTrustLevel().ordinal());
@ -264,27 +234,27 @@ public class IdentityKeyStore {
} }
} }
private void deleteIdentity(final Connection connection, final RecipientId recipientId) throws SQLException { private void deleteIdentity(final Connection connection, final ServiceId serviceId) throws SQLException {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS i DELETE FROM %s AS i
WHERE i.recipient_id = ? WHERE i.uuid = ?
""" """
).formatted(TABLE_IDENTITY); ).formatted(TABLE_IDENTITY);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.executeUpdate(); statement.executeUpdate();
} }
} }
private IdentityInfo getIdentityInfoFromResultSet(ResultSet resultSet) throws SQLException { private IdentityInfo getIdentityInfoFromResultSet(ResultSet resultSet) throws SQLException {
try { try {
final var recipientId = recipientIdCreator.create(resultSet.getLong("recipient_id")); final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
final var id = new IdentityKey(resultSet.getBytes("identity_key")); final var id = new IdentityKey(resultSet.getBytes("identity_key"));
final var trustLevel = TrustLevel.fromInt(resultSet.getInt("trust_level")); final var trustLevel = TrustLevel.fromInt(resultSet.getInt("trust_level"));
final var added = resultSet.getLong("added_timestamp"); final var added = resultSet.getLong("added_timestamp");
return new IdentityInfo(recipientId, id, trustLevel, added); return new IdentityInfo(serviceId, id, trustLevel, added);
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
logger.warn("Failed to load identity key, resetting: {}", e.getMessage()); logger.warn("Failed to load identity key, resetting: {}", e.getMessage());
return null; return null;

View file

@ -3,6 +3,7 @@ package org.asamk.signal.manager.storage.identities;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.asamk.signal.manager.api.TrustLevel; import org.asamk.signal.manager.api.TrustLevel;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.util.IOUtils; import org.asamk.signal.manager.util.IOUtils;
@ -27,16 +28,21 @@ public class LegacyIdentityKeyStore {
private static final ObjectMapper objectMapper = org.asamk.signal.manager.storage.Utils.createStorageObjectMapper(); private static final ObjectMapper objectMapper = org.asamk.signal.manager.storage.Utils.createStorageObjectMapper();
public static void migrate( public static void migrate(
final File identitiesPath, final RecipientResolver resolver, final IdentityKeyStore identityKeyStore final File identitiesPath,
final RecipientResolver resolver,
final RecipientAddressResolver addressResolver,
final IdentityKeyStore identityKeyStore
) { ) {
final var identities = getIdentities(identitiesPath, resolver); final var identities = getIdentities(identitiesPath, resolver, addressResolver);
identityKeyStore.addLegacyIdentities(identities); identityKeyStore.addLegacyIdentities(identities);
removeIdentityFiles(identitiesPath); removeIdentityFiles(identitiesPath);
} }
static final Pattern identityFileNamePattern = Pattern.compile("(\\d+)"); static final Pattern identityFileNamePattern = Pattern.compile("(\\d+)");
private static List<IdentityInfo> getIdentities(final File identitiesPath, final RecipientResolver resolver) { private static List<IdentityInfo> getIdentities(
final File identitiesPath, final RecipientResolver resolver, final RecipientAddressResolver addressResolver
) {
final var files = identitiesPath.listFiles(); final var files = identitiesPath.listFiles();
if (files == null) { if (files == null) {
return List.of(); return List.of();
@ -45,7 +51,7 @@ public class LegacyIdentityKeyStore {
.filter(f -> identityFileNamePattern.matcher(f.getName()).matches()) .filter(f -> identityFileNamePattern.matcher(f.getName()).matches())
.map(f -> resolver.resolveRecipient(Long.parseLong(f.getName()))) .map(f -> resolver.resolveRecipient(Long.parseLong(f.getName())))
.filter(Objects::nonNull) .filter(Objects::nonNull)
.map(recipientId -> loadIdentityLocked(recipientId, identitiesPath)) .map(recipientId -> loadIdentityLocked(recipientId, addressResolver, identitiesPath))
.filter(Objects::nonNull) .filter(Objects::nonNull)
.toList(); .toList();
} }
@ -59,7 +65,9 @@ public class LegacyIdentityKeyStore {
return new File(identitiesPath, String.valueOf(recipientId.id())); return new File(identitiesPath, String.valueOf(recipientId.id()));
} }
private static IdentityInfo loadIdentityLocked(final RecipientId recipientId, final File identitiesPath) { private static IdentityInfo loadIdentityLocked(
final RecipientId recipientId, RecipientAddressResolver addressResolver, final File identitiesPath
) {
final var file = getIdentityFile(recipientId, identitiesPath); final var file = getIdentityFile(recipientId, identitiesPath);
if (!file.exists()) { if (!file.exists()) {
return null; return null;
@ -71,7 +79,8 @@ public class LegacyIdentityKeyStore {
var trustLevel = TrustLevel.fromInt(storage.trustLevel()); var trustLevel = TrustLevel.fromInt(storage.trustLevel());
var added = storage.addedTimestamp(); var added = storage.addedTimestamp();
return new IdentityInfo(recipientId, id, trustLevel, added); final var serviceId = addressResolver.resolveRecipientAddress(recipientId).getServiceId();
return new IdentityInfo(serviceId, id, trustLevel, added);
} catch (IOException | InvalidKeyException e) { } catch (IOException | InvalidKeyException e) {
logger.warn("Failed to load identity key: {}", e.getMessage()); logger.warn("Failed to load identity key: {}", e.getMessage());
return null; return null;

View file

@ -1,16 +1,15 @@
package org.asamk.signal.manager.storage.identities; package org.asamk.signal.manager.storage.identities;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.function.Supplier; import java.util.function.Supplier;
public class SignalIdentityKeyStore implements org.signal.libsignal.protocol.state.IdentityKeyStore { public class SignalIdentityKeyStore implements org.signal.libsignal.protocol.state.IdentityKeyStore {
private final RecipientResolver resolver;
private final Supplier<IdentityKeyPair> identityKeyPairSupplier; private final Supplier<IdentityKeyPair> identityKeyPairSupplier;
private final int localRegistrationId; private final int localRegistrationId;
private final IdentityKeyStore identityKeyStore; private final IdentityKeyStore identityKeyStore;
@ -21,7 +20,6 @@ public class SignalIdentityKeyStore implements org.signal.libsignal.protocol.sta
final int localRegistrationId, final int localRegistrationId,
final IdentityKeyStore identityKeyStore final IdentityKeyStore identityKeyStore
) { ) {
this.resolver = resolver;
this.identityKeyPairSupplier = identityKeyPairSupplier; this.identityKeyPairSupplier = identityKeyPairSupplier;
this.localRegistrationId = localRegistrationId; this.localRegistrationId = localRegistrationId;
this.identityKeyStore = identityKeyStore; this.identityKeyStore = identityKeyStore;
@ -39,29 +37,22 @@ public class SignalIdentityKeyStore implements org.signal.libsignal.protocol.sta
@Override @Override
public boolean saveIdentity(SignalProtocolAddress address, IdentityKey identityKey) { public boolean saveIdentity(SignalProtocolAddress address, IdentityKey identityKey) {
final var recipientId = resolveRecipient(address.getName()); final var serviceId = ServiceId.parseOrThrow(address.getName());
return identityKeyStore.saveIdentity(recipientId, identityKey); return identityKeyStore.saveIdentity(serviceId, identityKey);
} }
@Override @Override
public boolean isTrustedIdentity(SignalProtocolAddress address, IdentityKey identityKey, Direction direction) { public boolean isTrustedIdentity(SignalProtocolAddress address, IdentityKey identityKey, Direction direction) {
var recipientId = resolveRecipient(address.getName()); final var serviceId = ServiceId.parseOrThrow(address.getName());
return identityKeyStore.isTrustedIdentity(recipientId, identityKey, direction); return identityKeyStore.isTrustedIdentity(serviceId, identityKey, direction);
} }
@Override @Override
public IdentityKey getIdentity(SignalProtocolAddress address) { public IdentityKey getIdentity(SignalProtocolAddress address) {
var recipientId = resolveRecipient(address.getName()); final var serviceId = ServiceId.parseOrThrow(address.getName());
final var identityInfo = identityKeyStore.getIdentityInfo(recipientId); final var identityInfo = identityKeyStore.getIdentityInfo(serviceId);
return identityInfo == null ? null : identityInfo.getIdentityKey(); return identityInfo == null ? null : identityInfo.getIdentityKey();
} }
/**
* @param identifier can be either a serialized uuid or an e164 phone number
*/
private RecipientId resolveRecipient(String identifier) {
return resolver.resolveRecipient(identifier);
}
} }

View file

@ -35,6 +35,10 @@ public record RecipientAddress(Optional<UUID> uuid, Optional<String> number) {
this(Optional.of(uuid), Optional.empty()); this(Optional.of(uuid), Optional.empty());
} }
public ServiceId getServiceId() {
return ServiceId.from(uuid.orElse(UNKNOWN_UUID));
}
public String getIdentifier() { public String getIdentifier() {
if (uuid.isPresent()) { if (uuid.isPresent()) {
return uuid.get().toString(); return uuid.get().toString();
@ -62,6 +66,6 @@ public record RecipientAddress(Optional<UUID> uuid, Optional<String> number) {
} }
public SignalServiceAddress toSignalServiceAddress() { public SignalServiceAddress toSignalServiceAddress() {
return new SignalServiceAddress(ServiceId.from(uuid.orElse(UNKNOWN_UUID)), number); return new SignalServiceAddress(getServiceId(), number);
} }
} }

View file

@ -4,14 +4,13 @@ import org.asamk.signal.manager.groups.GroupId;
import org.asamk.signal.manager.groups.GroupUtils; import org.asamk.signal.manager.groups.GroupUtils;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey; import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.crypto.ContentHint; import org.whispersystems.signalservice.api.crypto.ContentHint;
import org.whispersystems.signalservice.api.messages.SendMessageResult; import org.whispersystems.signalservice.api.messages.SendMessageResult;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.internal.push.SignalServiceProtos; import org.whispersystems.signalservice.internal.push.SignalServiceProtos;
import java.io.IOException; import java.io.IOException;
@ -32,14 +31,10 @@ public class MessageSendLogStore implements AutoCloseable {
private static final Duration LOG_DURATION = Duration.ofDays(1); private static final Duration LOG_DURATION = Duration.ofDays(1);
private final RecipientResolver recipientResolver;
private final Database database; private final Database database;
private final Thread cleanupThread; private final Thread cleanupThread;
public MessageSendLogStore( public MessageSendLogStore(final Database database) {
final RecipientResolver recipientResolver, final Database database
) {
this.recipientResolver = recipientResolver;
this.database = database; this.database = database;
this.cleanupThread = new Thread(() -> { this.cleanupThread = new Thread(() -> {
try { try {
@ -69,9 +64,9 @@ public class MessageSendLogStore implements AutoCloseable {
CREATE TABLE message_send_log ( CREATE TABLE message_send_log (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
content_id INTEGER NOT NULL REFERENCES message_send_log_content (_id) ON DELETE CASCADE, content_id INTEGER NOT NULL REFERENCES message_send_log_content (_id) ON DELETE CASCADE,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, uuid BLOB NOT NULL,
device_id INTEGER NOT NULL device_id INTEGER NOT NULL
); ) STRICT;
CREATE TABLE message_send_log_content ( CREATE TABLE message_send_log_content (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
group_id BLOB, group_id BLOB,
@ -81,26 +76,26 @@ public class MessageSendLogStore implements AutoCloseable {
urgent BOOLEAN NOT NULL urgent BOOLEAN NOT NULL
); );
CREATE INDEX mslc_timestamp_index ON message_send_log_content (timestamp); CREATE INDEX mslc_timestamp_index ON message_send_log_content (timestamp);
CREATE INDEX msl_recipient_index ON message_send_log (recipient_id, device_id, content_id); CREATE INDEX msl_recipient_index ON message_send_log (uuid, device_id, content_id);
CREATE INDEX msl_content_index ON message_send_log (content_id); CREATE INDEX msl_content_index ON message_send_log (content_id);
"""); """);
} }
} }
public List<MessageSendLogEntry> findMessages( public List<MessageSendLogEntry> findMessages(
final RecipientId recipientId, final int deviceId, final long timestamp, final boolean isSenderKey final ServiceId serviceId, final int deviceId, final long timestamp, final boolean isSenderKey
) { ) {
final var sql = """ final var sql = """
SELECT group_id, content, content_hint SELECT group_id, content, content_hint
FROM %s l FROM %s l
INNER JOIN %s lc ON l.content_id = lc._id INNER JOIN %s lc ON l.content_id = lc._id
WHERE l.recipient_id = ? AND l.device_id = ? AND lc.timestamp = ? WHERE l.uuid = ? AND l.device_id = ? AND lc.timestamp = ?
""".formatted(TABLE_MESSAGE_SEND_LOG, TABLE_MESSAGE_SEND_LOG_CONTENT); """.formatted(TABLE_MESSAGE_SEND_LOG, TABLE_MESSAGE_SEND_LOG_CONTENT);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
deleteOutdatedEntries(connection); deleteOutdatedEntries(connection);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.setInt(2, deviceId); statement.setInt(2, deviceId);
statement.setLong(3, timestamp); statement.setLong(3, timestamp);
try (var result = Utils.executeQueryForStream(statement, this::getMessageSendLogEntryFromResultSet)) { try (var result = Utils.executeQueryForStream(statement, this::getMessageSendLogEntryFromResultSet)) {
@ -189,16 +184,16 @@ public class MessageSendLogStore implements AutoCloseable {
} }
} }
public void deleteEntryForRecipientNonGroup(long sentTimestamp, RecipientId recipientId) { public void deleteEntryForRecipientNonGroup(long sentTimestamp, ServiceId serviceId) {
final var sql = """ final var sql = """
DELETE FROM %s AS lc DELETE FROM %s AS lc
WHERE lc.timestamp = ? AND lc.group_id IS NULL AND lc._id IN (SELECT content_id FROM %s l WHERE l.recipient_id = ?) WHERE lc.timestamp = ? AND lc.group_id IS NULL AND lc._id IN (SELECT content_id FROM %s l WHERE l.uuid = ?)
""".formatted(TABLE_MESSAGE_SEND_LOG_CONTENT, TABLE_MESSAGE_SEND_LOG); """.formatted(TABLE_MESSAGE_SEND_LOG_CONTENT, TABLE_MESSAGE_SEND_LOG);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
connection.setAutoCommit(false); connection.setAutoCommit(false);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, sentTimestamp); statement.setLong(1, sentTimestamp);
statement.setLong(2, recipientId.id()); statement.setBytes(2, serviceId.toByteArray());
statement.executeUpdate(); statement.executeUpdate();
} }
@ -209,21 +204,21 @@ public class MessageSendLogStore implements AutoCloseable {
} }
} }
public void deleteEntryForRecipient(long sentTimestamp, RecipientId recipientId, int deviceId) { public void deleteEntryForRecipient(long sentTimestamp, ServiceId serviceId, int deviceId) {
deleteEntriesForRecipient(List.of(sentTimestamp), recipientId, deviceId); deleteEntriesForRecipient(List.of(sentTimestamp), serviceId, deviceId);
} }
public void deleteEntriesForRecipient(List<Long> sentTimestamps, RecipientId recipientId, int deviceId) { public void deleteEntriesForRecipient(List<Long> sentTimestamps, ServiceId serviceId, int deviceId) {
final var sql = """ final var sql = """
DELETE FROM %s AS l DELETE FROM %s AS l
WHERE l.content_id IN (SELECT _id FROM %s lc WHERE lc.timestamp = ?) AND l.recipient_id = ? AND l.device_id = ? WHERE l.content_id IN (SELECT _id FROM %s lc WHERE lc.timestamp = ?) AND l.uuid = ? AND l.device_id = ?
""".formatted(TABLE_MESSAGE_SEND_LOG, TABLE_MESSAGE_SEND_LOG_CONTENT); """.formatted(TABLE_MESSAGE_SEND_LOG, TABLE_MESSAGE_SEND_LOG_CONTENT);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
connection.setAutoCommit(false); connection.setAutoCommit(false);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
for (final var sentTimestamp : sentTimestamps) { for (final var sentTimestamp : sentTimestamps) {
statement.setLong(1, sentTimestamp); statement.setLong(1, sentTimestamp);
statement.setLong(2, recipientId.id()); statement.setBytes(2, serviceId.toByteArray());
statement.setInt(3, deviceId); statement.setInt(3, deviceId);
statement.executeUpdate(); statement.executeUpdate();
} }
@ -247,8 +242,8 @@ public class MessageSendLogStore implements AutoCloseable {
private RecipientDevices getRecipientDevices(final SendMessageResult sendMessageResult) { private RecipientDevices getRecipientDevices(final SendMessageResult sendMessageResult) {
if (sendMessageResult.isSuccess() && sendMessageResult.getSuccess().getContent().isPresent()) { if (sendMessageResult.isSuccess() && sendMessageResult.getSuccess().getContent().isPresent()) {
final var recipientId = recipientResolver.resolveRecipient(sendMessageResult.getAddress()); final var serviceId = sendMessageResult.getAddress().getServiceId();
return new RecipientDevices(recipientId, sendMessageResult.getSuccess().getDevices()); return new RecipientDevices(serviceId, sendMessageResult.getSuccess().getDevices());
} else { } else {
return null; return null;
} }
@ -332,13 +327,13 @@ public class MessageSendLogStore implements AutoCloseable {
final long contentId, final List<RecipientDevices> recipientDevices, final Connection connection final long contentId, final List<RecipientDevices> recipientDevices, final Connection connection
) throws SQLException { ) throws SQLException {
final var sql = """ final var sql = """
INSERT INTO %s (recipient_id, device_id, content_id) INSERT INTO %s (uuid, device_id, content_id)
VALUES (?,?,?) VALUES (?,?,?)
""".formatted(TABLE_MESSAGE_SEND_LOG); """.formatted(TABLE_MESSAGE_SEND_LOG);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
for (final var recipientDevice : recipientDevices) { for (final var recipientDevice : recipientDevices) {
for (final var deviceId : recipientDevice.deviceIds()) { for (final var deviceId : recipientDevice.deviceIds()) {
statement.setLong(1, recipientDevice.recipientId().id()); statement.setBytes(1, recipientDevice.serviceId().toByteArray());
statement.setInt(2, deviceId); statement.setInt(2, deviceId);
statement.setLong(3, contentId); statement.setLong(3, contentId);
statement.executeUpdate(); statement.executeUpdate();
@ -387,5 +382,5 @@ public class MessageSendLogStore implements AutoCloseable {
return new MessageSendLogEntry(groupId, content, contentHint, urgent); return new MessageSendLogEntry(groupId, content, contentHint, urgent);
} }
private record RecipientDevices(RecipientId recipientId, List<Integer> deviceIds) {} private record RecipientDevices(ServiceId serviceId, List<Integer> deviceIds) {}
} }

View file

@ -1,11 +1,14 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -18,14 +21,15 @@ import java.util.UUID;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import static org.asamk.signal.manager.storage.senderKeys.SenderKeyRecordStore.Key;
public class LegacySenderKeyRecordStore { public class LegacySenderKeyRecordStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeyRecordStore.class); private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeyRecordStore.class);
public static void migrate( public static void migrate(
final File senderKeysPath, final RecipientResolver resolver, SenderKeyStore senderKeyStore final File senderKeysPath,
final RecipientResolver resolver,
final RecipientAddressResolver addressResolver,
final SenderKeyStore senderKeyStore
) { ) {
final var files = senderKeysPath.listFiles(); final var files = senderKeysPath.listFiles();
if (files == null) { if (files == null) {
@ -34,10 +38,13 @@ public class LegacySenderKeyRecordStore {
final var senderKeys = parseFileNames(files, resolver).stream().map(key -> { final var senderKeys = parseFileNames(files, resolver).stream().map(key -> {
final var record = loadSenderKeyLocked(key, senderKeysPath); final var record = loadSenderKeyLocked(key, senderKeysPath);
if (record == null) { final var uuid = addressResolver.resolveRecipientAddress(key.recipientId).uuid();
if (record == null || uuid.isEmpty()) {
return null; return null;
} }
return new Pair<>(key, record); return new Pair<>(new SenderKeyRecordStore.Key(ServiceId.from(uuid.get()),
key.deviceId,
key.distributionId), record);
}).filter(Objects::nonNull).toList(); }).filter(Objects::nonNull).toList();
senderKeyStore.addLegacySenderKeys(senderKeys); senderKeyStore.addLegacySenderKeys(senderKeys);
@ -98,4 +105,6 @@ public class LegacySenderKeyRecordStore {
return null; return null;
} }
} }
record Key(RecipientId recipientId, int deviceId, UUID distributionId) {}
} }

View file

@ -1,11 +1,13 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore.SenderKeySharedEntry; import org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore.SenderKeySharedEntry;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -21,7 +23,10 @@ public class LegacySenderKeySharedStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeySharedStore.class); private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeySharedStore.class);
public static void migrate( public static void migrate(
final File file, final RecipientResolver resolver, SenderKeyStore senderKeyStore final File file,
final RecipientResolver resolver,
final RecipientAddressResolver addressResolver,
final SenderKeyStore senderKeyStore
) { ) {
final var objectMapper = Utils.createStorageObjectMapper(); final var objectMapper = Utils.createStorageObjectMapper();
try (var inputStream = new FileInputStream(file)) { try (var inputStream = new FileInputStream(file)) {
@ -32,7 +37,11 @@ public class LegacySenderKeySharedStore {
if (recipientId == null) { if (recipientId == null) {
continue; continue;
} }
final var entry = new SenderKeySharedEntry(recipientId, senderKey.deviceId); final var uuid = addressResolver.resolveRecipientAddress(recipientId).uuid();
if (uuid.isEmpty()) {
continue;
}
final var entry = new SenderKeySharedEntry(ServiceId.from(uuid.get()), senderKey.deviceId);
final var distributionId = DistributionId.from(senderKey.distributionId); final var distributionId = DistributionId.from(senderKey.distributionId);
var entries = sharedSenderKeys.get(distributionId); var entries = sharedSenderKeys.get(distributionId);
if (entries == null) { if (entries == null) {

View file

@ -3,14 +3,13 @@ package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.signal.libsignal.protocol.groups.state.SenderKeyStore; import org.signal.libsignal.protocol.groups.state.SenderKeyStore;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
import java.sql.Connection; import java.sql.Connection;
@ -25,7 +24,6 @@ public class SenderKeyRecordStore implements SenderKeyStore {
private final static String TABLE_SENDER_KEY = "sender_key"; private final static String TABLE_SENDER_KEY = "sender_key";
private final Database database; private final Database database;
private final RecipientResolver resolver;
public static void createSql(Connection connection) throws SQLException { public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
@ -33,22 +31,19 @@ public class SenderKeyRecordStore implements SenderKeyStore {
statement.executeUpdate(""" statement.executeUpdate("""
CREATE TABLE sender_key ( CREATE TABLE sender_key (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, uuid BLOB NOT NULL,
device_id INTEGER NOT NULL, device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL, distribution_id BLOB NOT NULL,
record BLOB NOT NULL, record BLOB NOT NULL,
created_timestamp INTEGER NOT NULL, created_timestamp INTEGER NOT NULL,
UNIQUE(recipient_id, device_id, distribution_id) UNIQUE(uuid, device_id, distribution_id)
); ) STRICT;
"""); """);
} }
} }
SenderKeyRecordStore( SenderKeyRecordStore(final Database database) {
final Database database, final RecipientResolver resolver
) {
this.database = database; this.database = database;
this.resolver = resolver;
} }
@Override @Override
@ -75,17 +70,17 @@ public class SenderKeyRecordStore implements SenderKeyStore {
} }
} }
long getCreateTimeForKey(final RecipientId selfRecipientId, final int selfDeviceId, final UUID distributionId) { long getCreateTimeForKey(final ServiceId selfServiceId, final int selfDeviceId, final UUID distributionId) {
final var sql = ( final var sql = (
""" """
SELECT s.created_timestamp SELECT s.created_timestamp
FROM %s AS s FROM %s AS s
WHERE s.recipient_id = ? AND s.device_id = ? AND s.distribution_id = ? WHERE s.uuid = ? AND s.device_id = ? AND s.distribution_id = ?
""" """
).formatted(TABLE_SENDER_KEY); ).formatted(TABLE_SENDER_KEY);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, selfRecipientId.id()); statement.setBytes(1, selfServiceId.toByteArray());
statement.setInt(2, selfDeviceId); statement.setInt(2, selfDeviceId);
statement.setBytes(3, UuidUtil.toByteArray(distributionId)); statement.setBytes(3, UuidUtil.toByteArray(distributionId));
return Utils.executeQueryForOptional(statement, res -> res.getLong("created_timestamp")).orElse(-1L); return Utils.executeQueryForOptional(statement, res -> res.getLong("created_timestamp")).orElse(-1L);
@ -95,16 +90,16 @@ public class SenderKeyRecordStore implements SenderKeyStore {
} }
} }
void deleteSenderKey(final RecipientId recipientId, final UUID distributionId) { void deleteSenderKey(final ServiceId serviceId, final UUID distributionId) {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE s.recipient_id = ? AND s.distribution_id = ? WHERE s.uuid = ? AND s.distribution_id = ?
""" """
).formatted(TABLE_SENDER_KEY); ).formatted(TABLE_SENDER_KEY);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.setBytes(2, UuidUtil.toByteArray(distributionId)); statement.setBytes(2, UuidUtil.toByteArray(distributionId));
statement.executeUpdate(); statement.executeUpdate();
} }
@ -126,33 +121,9 @@ public class SenderKeyRecordStore implements SenderKeyStore {
} }
} }
void deleteAllFor(final RecipientId recipientId) { void deleteAllFor(final ServiceId serviceId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
deleteAllFor(connection, recipientId); deleteAllFor(connection, serviceId);
} catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e);
}
}
void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
try (final var connection = database.getConnection()) {
connection.setAutoCommit(false);
final var sql = """
UPDATE OR IGNORE %s
SET recipient_id = ?
WHERE recipient_id = ?
""".formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.setLong(2, toBeMergedRecipientId.id());
final var rows = statement.executeUpdate();
if (rows > 0) {
logger.debug("Reassigned {} sender keys of to be merged recipient.", rows);
}
}
// Delete all conflicting sender keys now
deleteAllFor(connection, toBeMergedRecipientId);
connection.commit();
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e); throw new RuntimeException("Failed update sender key store", e);
} }
@ -173,16 +144,9 @@ public class SenderKeyRecordStore implements SenderKeyStore {
logger.debug("Complete sender keys migration took {}ms", (System.nanoTime() - start) / 1000000); logger.debug("Complete sender keys migration took {}ms", (System.nanoTime() - start) / 1000000);
} }
/**
* @param identifier can be either a serialized uuid or an e164 phone number
*/
private RecipientId resolveRecipient(String identifier) {
return resolver.resolveRecipient(identifier);
}
private Key getKey(final SignalProtocolAddress address, final UUID distributionId) { private Key getKey(final SignalProtocolAddress address, final UUID distributionId) {
final var recipientId = resolveRecipient(address.getName()); final var serviceId = ServiceId.parseOrThrow(address.getName());
return new Key(recipientId, address.getDeviceId(), distributionId); return new Key(serviceId, address.getDeviceId(), distributionId);
} }
private SenderKeyRecord loadSenderKey(final Connection connection, final Key key) throws SQLException { private SenderKeyRecord loadSenderKey(final Connection connection, final Key key) throws SQLException {
@ -190,11 +154,11 @@ public class SenderKeyRecordStore implements SenderKeyStore {
""" """
SELECT s.record SELECT s.record
FROM %s AS s FROM %s AS s
WHERE s.recipient_id = ? AND s.device_id = ? AND s.distribution_id = ? WHERE s.uuid = ? AND s.device_id = ? AND s.distribution_id = ?
""" """
).formatted(TABLE_SENDER_KEY); ).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, key.recipientId().id()); statement.setBytes(1, key.serviceId().toByteArray());
statement.setInt(2, key.deviceId()); statement.setInt(2, key.deviceId());
statement.setBytes(3, UuidUtil.toByteArray(key.distributionId())); statement.setBytes(3, UuidUtil.toByteArray(key.distributionId()));
return Utils.executeQueryForOptional(statement, this::getSenderKeyRecordFromResultSet).orElse(null); return Utils.executeQueryForOptional(statement, this::getSenderKeyRecordFromResultSet).orElse(null);
@ -207,11 +171,11 @@ public class SenderKeyRecordStore implements SenderKeyStore {
final var sqlUpdate = """ final var sqlUpdate = """
UPDATE %s UPDATE %s
SET record = ? SET record = ?
WHERE recipient_id = ? AND device_id = ? and distribution_id = ? WHERE uuid = ? AND device_id = ? and distribution_id = ?
""".formatted(TABLE_SENDER_KEY); """.formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sqlUpdate)) { try (final var statement = connection.prepareStatement(sqlUpdate)) {
statement.setBytes(1, senderKeyRecord.serialize()); statement.setBytes(1, senderKeyRecord.serialize());
statement.setLong(2, key.recipientId().id()); statement.setBytes(2, key.serviceId().toByteArray());
statement.setLong(3, key.deviceId()); statement.setLong(3, key.deviceId());
statement.setBytes(4, UuidUtil.toByteArray(key.distributionId())); statement.setBytes(4, UuidUtil.toByteArray(key.distributionId()));
final var rows = statement.executeUpdate(); final var rows = statement.executeUpdate();
@ -223,12 +187,12 @@ public class SenderKeyRecordStore implements SenderKeyStore {
// Record doesn't exist yet, creating a new one // Record doesn't exist yet, creating a new one
final var sqlInsert = ( final var sqlInsert = (
""" """
INSERT OR REPLACE INTO %s (recipient_id, device_id, distribution_id, record, created_timestamp) INSERT OR REPLACE INTO %s (uuid, device_id, distribution_id, record, created_timestamp)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
""" """
).formatted(TABLE_SENDER_KEY); ).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sqlInsert)) { try (final var statement = connection.prepareStatement(sqlInsert)) {
statement.setLong(1, key.recipientId().id()); statement.setBytes(1, key.serviceId().toByteArray());
statement.setInt(2, key.deviceId()); statement.setInt(2, key.deviceId());
statement.setBytes(3, UuidUtil.toByteArray(key.distributionId())); statement.setBytes(3, UuidUtil.toByteArray(key.distributionId()));
statement.setBytes(4, senderKeyRecord.serialize()); statement.setBytes(4, senderKeyRecord.serialize());
@ -237,15 +201,15 @@ public class SenderKeyRecordStore implements SenderKeyStore {
} }
} }
private void deleteAllFor(final Connection connection, final RecipientId recipientId) throws SQLException { private void deleteAllFor(final Connection connection, final ServiceId serviceId) throws SQLException {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE s.recipient_id = ? WHERE s.uuid = ?
""" """
).formatted(TABLE_SENDER_KEY); ).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.executeUpdate(); statement.executeUpdate();
} }
} }
@ -261,5 +225,5 @@ public class SenderKeyRecordStore implements SenderKeyStore {
} }
} }
record Key(RecipientId recipientId, int deviceId, UUID distributionId) {} record Key(ServiceId serviceId, int deviceId, UUID distributionId) {}
} }

View file

@ -1,15 +1,12 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
import java.sql.Connection; import java.sql.Connection;
@ -26,9 +23,6 @@ public class SenderKeySharedStore {
private final static String TABLE_SENDER_KEY_SHARED = "sender_key_shared"; private final static String TABLE_SENDER_KEY_SHARED = "sender_key_shared";
private final Database database; private final Database database;
private final RecipientIdCreator recipientIdCreator;
private final RecipientResolver resolver;
private final RecipientAddressResolver addressResolver;
public static void createSql(Connection connection) throws SQLException { public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
@ -36,33 +30,25 @@ public class SenderKeySharedStore {
statement.executeUpdate(""" statement.executeUpdate("""
CREATE TABLE sender_key_shared ( CREATE TABLE sender_key_shared (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, uuid BLOB NOT NULL,
device_id INTEGER NOT NULL, device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL, distribution_id BLOB NOT NULL,
timestamp INTEGER NOT NULL, timestamp INTEGER NOT NULL,
UNIQUE(recipient_id, device_id, distribution_id) UNIQUE(uuid, device_id, distribution_id)
); ) STRICT;
"""); """);
} }
} }
SenderKeySharedStore( SenderKeySharedStore(final Database database) {
final Database database,
final RecipientIdCreator recipientIdCreator,
final RecipientAddressResolver addressResolver,
final RecipientResolver resolver
) {
this.database = database; this.database = database;
this.recipientIdCreator = recipientIdCreator;
this.addressResolver = addressResolver;
this.resolver = resolver;
} }
public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) { public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var sql = ( final var sql = (
""" """
SELECT s.recipient_id, s.device_id SELECT s.uuid, s.device_id
FROM %s AS s FROM %s AS s
WHERE s.distribution_id = ? WHERE s.distribution_id = ?
""" """
@ -70,8 +56,7 @@ public class SenderKeySharedStore {
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setBytes(1, UuidUtil.toByteArray(distributionId.asUuid())); statement.setBytes(1, UuidUtil.toByteArray(distributionId.asUuid()));
return Utils.executeQueryForStream(statement, this::getSenderKeySharedEntryFromResultSet) return Utils.executeQueryForStream(statement, this::getSenderKeySharedEntryFromResultSet)
.map(k -> new SignalProtocolAddress(addressResolver.resolveRecipientAddress(k.recipientId()) .map(k -> k.serviceId.toProtocolAddress(k.deviceId()))
.getIdentifier(), k.deviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
} catch (SQLException e) { } catch (SQLException e) {
@ -83,7 +68,7 @@ public class SenderKeySharedStore {
final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
) { ) {
final var newEntries = addresses.stream() final var newEntries = addresses.stream()
.map(a -> new SenderKeySharedEntry(resolver.resolveRecipient(a.getName()), a.getDeviceId())) .map(a -> new SenderKeySharedEntry(ServiceId.parseOrThrow(a.getName()), a.getDeviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
@ -97,7 +82,8 @@ public class SenderKeySharedStore {
public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) { public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
final var entriesToDelete = addresses.stream() final var entriesToDelete = addresses.stream()
.map(a -> new SenderKeySharedEntry(resolver.resolveRecipient(a.getName()), a.getDeviceId())) .filter(a -> UuidUtil.isUuid(a.getName()))
.map(a -> new SenderKeySharedEntry(ServiceId.parseOrThrow(a.getName()), a.getDeviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
@ -105,12 +91,12 @@ public class SenderKeySharedStore {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE recipient_id = ? AND device_id = ? WHERE uuid = ? AND device_id = ?
""" """
).formatted(TABLE_SENDER_KEY_SHARED); ).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
for (final var entry : entriesToDelete) { for (final var entry : entriesToDelete) {
statement.setLong(1, entry.recipientId().id()); statement.setBytes(1, entry.serviceId().toByteArray());
statement.setInt(2, entry.deviceId()); statement.setInt(2, entry.deviceId());
statement.executeUpdate(); statement.executeUpdate();
} }
@ -136,16 +122,16 @@ public class SenderKeySharedStore {
} }
} }
public void deleteAllFor(final RecipientId recipientId) { public void deleteAllFor(final ServiceId serviceId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE recipient_id = ? WHERE uuid = ?
""" """
).formatted(TABLE_SENDER_KEY_SHARED); ).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.executeUpdate(); statement.executeUpdate();
} }
} catch (SQLException e) { } catch (SQLException e) {
@ -154,17 +140,17 @@ public class SenderKeySharedStore {
} }
public void deleteSharedWith( public void deleteSharedWith(
final RecipientId recipientId, final int deviceId, final DistributionId distributionId final ServiceId serviceId, final int deviceId, final DistributionId distributionId
) { ) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE recipient_id = ? AND device_id = ? AND distribution_id = ? WHERE uuid = ? AND device_id = ? AND distribution_id = ?
""" """
).formatted(TABLE_SENDER_KEY_SHARED); ).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id()); statement.setBytes(1, serviceId.toByteArray());
statement.setInt(2, deviceId); statement.setInt(2, deviceId);
statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid())); statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid()));
statement.executeUpdate(); statement.executeUpdate();
@ -191,25 +177,6 @@ public class SenderKeySharedStore {
} }
} }
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
try (final var connection = database.getConnection()) {
final var sql = (
"""
UPDATE OR REPLACE %s
SET recipient_id = ?
WHERE recipient_id = ?
"""
).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.setLong(2, toBeMergedRecipientId.id());
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
}
}
void addLegacySenderKeysShared(final Map<DistributionId, Set<SenderKeySharedEntry>> sharedSenderKeys) { void addLegacySenderKeysShared(final Map<DistributionId, Set<SenderKeySharedEntry>> sharedSenderKeys) {
logger.debug("Migrating legacy sender keys shared to database"); logger.debug("Migrating legacy sender keys shared to database");
long start = System.nanoTime(); long start = System.nanoTime();
@ -230,13 +197,13 @@ public class SenderKeySharedStore {
) throws SQLException { ) throws SQLException {
final var sql = ( final var sql = (
""" """
INSERT OR REPLACE INTO %s (recipient_id, device_id, distribution_id, timestamp) INSERT OR REPLACE INTO %s (uuid, device_id, distribution_id, timestamp)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
""" """
).formatted(TABLE_SENDER_KEY_SHARED); ).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
for (final var entry : newEntries) { for (final var entry : newEntries) {
statement.setLong(1, entry.recipientId().id()); statement.setBytes(1, entry.serviceId().toByteArray());
statement.setInt(2, entry.deviceId()); statement.setInt(2, entry.deviceId());
statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid())); statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid()));
statement.setLong(4, System.currentTimeMillis()); statement.setLong(4, System.currentTimeMillis());
@ -246,10 +213,10 @@ public class SenderKeySharedStore {
} }
private SenderKeySharedEntry getSenderKeySharedEntryFromResultSet(ResultSet resultSet) throws SQLException { private SenderKeySharedEntry getSenderKeySharedEntryFromResultSet(ResultSet resultSet) throws SQLException {
final var recipientId = resultSet.getLong("recipient_id"); final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
final var deviceId = resultSet.getInt("device_id"); final var deviceId = resultSet.getInt("device_id");
return new SenderKeySharedEntry(recipientIdCreator.create(recipientId), deviceId); return new SenderKeySharedEntry(serviceId, deviceId);
} }
record SenderKeySharedEntry(RecipientId recipientId, int deviceId) {} record SenderKeySharedEntry(ServiceId serviceId, int deviceId) {}
} }

View file

@ -1,15 +1,12 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore; import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
@ -21,14 +18,9 @@ public class SenderKeyStore implements SignalServiceSenderKeyStore {
private final SenderKeyRecordStore senderKeyRecordStore; private final SenderKeyRecordStore senderKeyRecordStore;
private final SenderKeySharedStore senderKeySharedStore; private final SenderKeySharedStore senderKeySharedStore;
public SenderKeyStore( public SenderKeyStore(final Database database) {
final Database database, this.senderKeyRecordStore = new SenderKeyRecordStore(database);
final RecipientAddressResolver addressResolver, this.senderKeySharedStore = new SenderKeySharedStore(database);
final RecipientResolver resolver,
final RecipientIdCreator recipientIdCreator
) {
this.senderKeyRecordStore = new SenderKeyRecordStore(database, resolver);
this.senderKeySharedStore = new SenderKeySharedStore(database, recipientIdCreator, addressResolver, resolver);
} }
@Override @Override
@ -65,31 +57,26 @@ public class SenderKeyStore implements SignalServiceSenderKeyStore {
senderKeyRecordStore.deleteAll(); senderKeyRecordStore.deleteAll();
} }
public void deleteAll(RecipientId recipientId) { public void deleteAll(ServiceId serviceId) {
senderKeySharedStore.deleteAllFor(recipientId); senderKeySharedStore.deleteAllFor(serviceId);
senderKeyRecordStore.deleteAllFor(recipientId); senderKeyRecordStore.deleteAllFor(serviceId);
} }
public void deleteSharedWith(RecipientId recipientId) { public void deleteSharedWith(ServiceId serviceId) {
senderKeySharedStore.deleteAllFor(recipientId); senderKeySharedStore.deleteAllFor(serviceId);
} }
public void deleteSharedWith(RecipientId recipientId, int deviceId, DistributionId distributionId) { public void deleteSharedWith(ServiceId serviceId, int deviceId, DistributionId distributionId) {
senderKeySharedStore.deleteSharedWith(recipientId, deviceId, distributionId); senderKeySharedStore.deleteSharedWith(serviceId, deviceId, distributionId);
} }
public void deleteOurKey(RecipientId selfRecipientId, DistributionId distributionId) { public void deleteOurKey(ServiceId selfServiceId, DistributionId distributionId) {
senderKeySharedStore.deleteAllFor(distributionId); senderKeySharedStore.deleteAllFor(distributionId);
senderKeyRecordStore.deleteSenderKey(selfRecipientId, distributionId.asUuid()); senderKeyRecordStore.deleteSenderKey(selfServiceId, distributionId.asUuid());
} }
public long getCreateTimeForOurKey(RecipientId selfRecipientId, int deviceId, DistributionId distributionId) { public long getCreateTimeForOurKey(ServiceId selfServiceId, int deviceId, DistributionId distributionId) {
return senderKeyRecordStore.getCreateTimeForKey(selfRecipientId, deviceId, distributionId.asUuid()); return senderKeyRecordStore.getCreateTimeForKey(selfServiceId, deviceId, distributionId.asUuid());
}
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
senderKeySharedStore.mergeRecipients(recipientId, toBeMergedRecipientId);
senderKeyRecordStore.mergeRecipients(recipientId, toBeMergedRecipientId);
} }
void addLegacySenderKeys(final Collection<Pair<SenderKeyRecordStore.Key, SenderKeyRecord>> senderKeys) { void addLegacySenderKeys(final Collection<Pair<SenderKeyRecordStore.Key, SenderKeyRecord>> senderKeys) {

View file

@ -1,12 +1,14 @@
package org.asamk.signal.manager.storage.sessions; package org.asamk.signal.manager.storage.sessions;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.storage.sessions.SessionStore.Key;
import org.asamk.signal.manager.util.IOUtils; import org.asamk.signal.manager.util.IOUtils;
import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SessionRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -24,15 +26,19 @@ public class LegacySessionStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySessionStore.class); private final static Logger logger = LoggerFactory.getLogger(LegacySessionStore.class);
public static void migrate( public static void migrate(
final File sessionsPath, final RecipientResolver resolver, final SessionStore sessionStore final File sessionsPath,
final RecipientResolver resolver,
final RecipientAddressResolver addressResolver,
final SessionStore sessionStore
) { ) {
final var keys = getKeysLocked(sessionsPath, resolver); final var keys = getKeysLocked(sessionsPath, resolver);
final var sessions = keys.stream().map(key -> { final var sessions = keys.stream().map(key -> {
final var record = loadSessionLocked(key, sessionsPath); final var record = loadSessionLocked(key, sessionsPath);
if (record == null) { final var uuid = addressResolver.resolveRecipientAddress(key.recipientId).uuid();
if (record == null || uuid.isEmpty()) {
return null; return null;
} }
return new Pair<>(key, record); return new Pair<>(new SessionStore.Key(ServiceId.from(uuid.get()), key.deviceId()), record);
}).filter(Objects::nonNull).toList(); }).filter(Objects::nonNull).toList();
sessionStore.addLegacySessions(sessions); sessionStore.addLegacySessions(sessions);
deleteAllSessions(sessionsPath); deleteAllSessions(sessionsPath);
@ -104,4 +110,6 @@ public class LegacySessionStore {
return null; return null;
} }
} }
record Key(RecipientId recipientId, int deviceId) {}
} }

View file

@ -3,9 +3,6 @@ package org.asamk.signal.manager.storage.sessions;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.Database; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.NoSessionException; import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
@ -15,7 +12,9 @@ import org.signal.libsignal.protocol.state.SessionRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceIdType; import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.util.UuidUtil;
import java.sql.Connection; import java.sql.Connection;
import java.sql.ResultSet; import java.sql.ResultSet;
@ -38,8 +37,6 @@ public class SessionStore implements SignalServiceSessionStore {
private final Database database; private final Database database;
private final int accountIdType; private final int accountIdType;
private final RecipientResolver resolver;
private final RecipientIdCreator recipientIdCreator;
public static void createSql(Connection connection) throws SQLException { public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
@ -48,25 +45,18 @@ public class SessionStore implements SignalServiceSessionStore {
CREATE TABLE session ( CREATE TABLE session (
_id INTEGER PRIMARY KEY, _id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL, account_id_type INTEGER NOT NULL,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, uuid BLOB NOT NULL,
device_id INTEGER NOT NULL, device_id INTEGER NOT NULL,
record BLOB NOT NULL, record BLOB NOT NULL,
UNIQUE(account_id_type, recipient_id, device_id) UNIQUE(account_id_type, uuid, device_id)
); ) STRICT;
"""); """);
} }
} }
public SessionStore( public SessionStore(final Database database, final ServiceIdType serviceIdType) {
final Database database,
final ServiceIdType serviceIdType,
final RecipientResolver resolver,
final RecipientIdCreator recipientIdCreator
) {
this.database = database; this.database = database;
this.accountIdType = Utils.getAccountIdType(serviceIdType); this.accountIdType = Utils.getAccountIdType(serviceIdType);
this.resolver = resolver;
this.recipientIdCreator = recipientIdCreator;
} }
@Override @Override
@ -111,19 +101,19 @@ public class SessionStore implements SignalServiceSessionStore {
@Override @Override
public List<Integer> getSubDeviceSessions(String name) { public List<Integer> getSubDeviceSessions(String name) {
final var recipientId = resolver.resolveRecipient(name); final var serviceId = ServiceId.parseOrThrow(name);
// get all sessions for recipient except primary device session // get all sessions for recipient except primary device session
final var sql = ( final var sql = (
""" """
SELECT s.device_id SELECT s.device_id
FROM %s AS s FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
""" """
).formatted(TABLE_SESSION); ).formatted(TABLE_SESSION);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, recipientId.id()); statement.setBytes(2, serviceId.toByteArray());
return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList(); return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
} }
} catch (SQLException e) { } catch (SQLException e) {
@ -131,8 +121,8 @@ public class SessionStore implements SignalServiceSessionStore {
} }
} }
public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) { public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
final var key = new Key(recipientId, deviceId); final var key = new Key(serviceId, deviceId);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
final var session = loadSession(connection, key); final var session = loadSession(connection, key);
@ -181,13 +171,13 @@ public class SessionStore implements SignalServiceSessionStore {
@Override @Override
public void deleteAllSessions(String name) { public void deleteAllSessions(String name) {
final var recipientId = resolver.resolveRecipient(name); final var serviceId = ServiceId.parseOrThrow(name);
deleteAllSessions(recipientId); deleteAllSessions(serviceId);
} }
public void deleteAllSessions(RecipientId recipientId) { public void deleteAllSessions(ServiceId serviceId) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
deleteAllSessions(connection, recipientId); deleteAllSessions(connection, serviceId);
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update session store", e); throw new RuntimeException("Failed update session store", e);
} }
@ -195,6 +185,10 @@ public class SessionStore implements SignalServiceSessionStore {
@Override @Override
public void archiveSession(final SignalProtocolAddress address) { public void archiveSession(final SignalProtocolAddress address) {
if (!UuidUtil.isUuid(address.getName())) {
return;
}
final var key = getKey(address); final var key = getKey(address);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
@ -212,19 +206,17 @@ public class SessionStore implements SignalServiceSessionStore {
@Override @Override
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) { public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
final var recipientIdToNameMap = addressNames.stream() final var serviceIdsCommaSeparated = addressNames.stream()
.collect(Collectors.toMap(resolver::resolveRecipient, name -> name)); .map(ServiceId::parseOrThrow)
final var recipientIdsCommaSeparated = recipientIdToNameMap.keySet() .map(ServiceId::toString)
.stream()
.map(recipientId -> String.valueOf(recipientId.id()))
.collect(Collectors.joining(",")); .collect(Collectors.joining(","));
final var sql = ( final var sql = (
""" """
SELECT s.recipient_id, s.device_id, s.record SELECT s.uuid, s.device_id, s.record
FROM %s AS s FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id IN (%s) WHERE s.account_id_type = ? AND s.uuid IN (%s)
""" """
).formatted(TABLE_SESSION, recipientIdsCommaSeparated); ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
@ -232,8 +224,7 @@ public class SessionStore implements SignalServiceSessionStore {
res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))) res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
.filter(pair -> isActive(pair.second())) .filter(pair -> isActive(pair.second()))
.map(Pair::first) .map(Pair::first)
.map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), .map(key -> key.serviceId().toProtocolAddress(key.deviceId()))
key.deviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
} catch (SQLException e) { } catch (SQLException e) {
@ -244,7 +235,7 @@ public class SessionStore implements SignalServiceSessionStore {
public void archiveAllSessions() { public void archiveAllSessions() {
final var sql = ( final var sql = (
""" """
SELECT s.recipient_id, s.device_id, s.record SELECT s.uuid, s.device_id, s.record
FROM %s AS s FROM %s AS s
WHERE s.account_id_type = ? WHERE s.account_id_type = ?
""" """
@ -267,12 +258,12 @@ public class SessionStore implements SignalServiceSessionStore {
} }
} }
public void archiveSessions(final RecipientId recipientId) { public void archiveSessions(final ServiceId serviceId) {
final var sql = ( final var sql = (
""" """
SELECT s.recipient_id, s.device_id, s.record SELECT s.uuid, s.device_id, s.record
FROM %s AS s FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id = ? WHERE s.account_id_type = ? AND s.uuid = ?
""" """
).formatted(TABLE_SESSION); ).formatted(TABLE_SESSION);
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
@ -280,7 +271,7 @@ public class SessionStore implements SignalServiceSessionStore {
final List<Pair<Key, SessionRecord>> records; final List<Pair<Key, SessionRecord>> records;
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, recipientId.id()); statement.setBytes(2, serviceId.toByteArray());
records = Utils.executeQueryForStream(statement, records = Utils.executeQueryForStream(statement,
res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList(); res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
} }
@ -294,35 +285,6 @@ public class SessionStore implements SignalServiceSessionStore {
} }
} }
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
try (final var connection = database.getConnection()) {
connection.setAutoCommit(false);
synchronized (cachedSessions) {
cachedSessions.clear();
}
final var sql = """
UPDATE OR IGNORE %s
SET recipient_id = ?
WHERE account_id_type = ? AND recipient_id = ?
""".formatted(TABLE_SESSION);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.setInt(2, accountIdType);
statement.setLong(3, toBeMergedRecipientId.id());
final var rows = statement.executeUpdate();
if (rows > 0) {
logger.debug("Reassigned {} sessions of to be merged recipient.", rows);
}
}
// Delete all conflicting sessions now
deleteAllSessions(connection, toBeMergedRecipientId);
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update session store", e);
}
}
void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) { void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
logger.debug("Migrating legacy sessions to database"); logger.debug("Migrating legacy sessions to database");
long start = System.nanoTime(); long start = System.nanoTime();
@ -339,8 +301,8 @@ public class SessionStore implements SignalServiceSessionStore {
} }
private Key getKey(final SignalProtocolAddress address) { private Key getKey(final SignalProtocolAddress address) {
final var recipientId = resolver.resolveRecipient(address.getName()); final var serviceId = ServiceId.parseOrThrow(address.getName());
return new Key(recipientId, address.getDeviceId()); return new Key(serviceId, address.getDeviceId());
} }
private SessionRecord loadSession(Connection connection, final Key key) throws SQLException { private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
@ -354,21 +316,21 @@ public class SessionStore implements SignalServiceSessionStore {
""" """
SELECT s.record SELECT s.record
FROM %s AS s FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ? WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
""" """
).formatted(TABLE_SESSION); ).formatted(TABLE_SESSION);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, key.recipientId().id()); statement.setBytes(2, key.serviceId().toByteArray());
statement.setInt(3, key.deviceId()); statement.setInt(3, key.deviceId());
return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null); return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
} }
} }
private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException { private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
final var recipientId = resultSet.getLong("recipient_id"); final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
final var deviceId = resultSet.getInt("device_id"); final var deviceId = resultSet.getInt("device_id");
return new Key(recipientIdCreator.create(recipientId), deviceId); return new Key(serviceId, deviceId);
} }
private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException { private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
@ -389,19 +351,19 @@ public class SessionStore implements SignalServiceSessionStore {
} }
final var sql = """ final var sql = """
INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record) INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
VALUES (?, ?, ?, ?) VALUES (?, ?, ?, ?)
""".formatted(TABLE_SESSION); """.formatted(TABLE_SESSION);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, key.recipientId().id()); statement.setBytes(2, key.serviceId().toByteArray());
statement.setInt(3, key.deviceId()); statement.setInt(3, key.deviceId());
statement.setBytes(4, session.serialize()); statement.setBytes(4, session.serialize());
statement.executeUpdate(); statement.executeUpdate();
} }
} }
private void deleteAllSessions(final Connection connection, final RecipientId recipientId) throws SQLException { private void deleteAllSessions(final Connection connection, final ServiceId serviceId) throws SQLException {
synchronized (cachedSessions) { synchronized (cachedSessions) {
cachedSessions.clear(); cachedSessions.clear();
} }
@ -409,12 +371,12 @@ public class SessionStore implements SignalServiceSessionStore {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id = ? WHERE s.account_id_type = ? AND s.uuid = ?
""" """
).formatted(TABLE_SESSION); ).formatted(TABLE_SESSION);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, recipientId.id()); statement.setBytes(2, serviceId.toByteArray());
statement.executeUpdate(); statement.executeUpdate();
} }
} }
@ -427,12 +389,12 @@ public class SessionStore implements SignalServiceSessionStore {
final var sql = ( final var sql = (
""" """
DELETE FROM %s AS s DELETE FROM %s AS s
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ? WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
""" """
).formatted(TABLE_SESSION); ).formatted(TABLE_SESSION);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType); statement.setInt(1, accountIdType);
statement.setLong(2, key.recipientId().id()); statement.setBytes(2, key.serviceId().toByteArray());
statement.setInt(3, key.deviceId()); statement.setInt(3, key.deviceId());
statement.executeUpdate(); statement.executeUpdate();
} }
@ -444,5 +406,5 @@ public class SessionStore implements SignalServiceSessionStore {
&& record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
} }
record Key(RecipientId recipientId, int deviceId) {} record Key(ServiceId serviceId, int deviceId) {}
} }

View file

@ -1,12 +1,12 @@
package org.asamk.signal.manager.util; package org.asamk.signal.manager.util;
import org.asamk.signal.manager.api.Pair; import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.recipients.RecipientAddress;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.fingerprint.Fingerprint; import org.signal.libsignal.protocol.fingerprint.Fingerprint;
import org.signal.libsignal.protocol.fingerprint.NumericFingerprintGenerator; import org.signal.libsignal.protocol.fingerprint.NumericFingerprintGenerator;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.util.StreamDetails; import org.whispersystems.signalservice.api.util.StreamDetails;
import java.io.BufferedInputStream; import java.io.BufferedInputStream;
@ -73,29 +73,28 @@ public class Utils {
public static Fingerprint computeSafetyNumber( public static Fingerprint computeSafetyNumber(
boolean isUuidCapable, boolean isUuidCapable,
SignalServiceAddress ownAddress, RecipientAddress ownAddress,
IdentityKey ownIdentityKey, IdentityKey ownIdentityKey,
SignalServiceAddress theirAddress, RecipientAddress theirAddress,
IdentityKey theirIdentityKey IdentityKey theirIdentityKey
) { ) {
int version; int version;
byte[] ownId; byte[] ownId;
byte[] theirId; byte[] theirId;
if (isUuidCapable) { if (!isUuidCapable && ownAddress.number().isPresent() && theirAddress.number().isPresent()) {
// Version 1: E164 user
version = 1;
ownId = ownAddress.number().get().getBytes();
theirId = theirAddress.number().get().getBytes();
} else if (isUuidCapable && theirAddress.uuid().isPresent()) {
// Version 2: UUID user // Version 2: UUID user
version = 2; version = 2;
ownId = ownAddress.getServiceId().toByteArray(); ownId = ownAddress.getServiceId().toByteArray();
theirId = theirAddress.getServiceId().toByteArray(); theirId = theirAddress.getServiceId().toByteArray();
} else { } else {
// Version 1: E164 user
version = 1;
if (ownAddress.getNumber().isEmpty() || theirAddress.getNumber().isEmpty()) {
return null; return null;
} }
ownId = ownAddress.getNumber().get().getBytes();
theirId = theirAddress.getNumber().get().getBytes();
}
return new NumericFingerprintGenerator(5200).createFor(version, return new NumericFingerprintGenerator(5200).createFor(version,
ownId, ownId,