Add PNI to recipients

This commit is contained in:
AsamK 2022-10-16 19:17:43 +02:00
parent e450f36e81
commit b9eee539bd
10 changed files with 696 additions and 110 deletions

View file

@ -54,6 +54,9 @@
<option name="TERNARY_OPERATION_WRAP" value="5" /> <option name="TERNARY_OPERATION_WRAP" value="5" />
<option name="TERNARY_OPERATION_SIGNS_ON_NEXT_LINE" value="true" /> <option name="TERNARY_OPERATION_SIGNS_ON_NEXT_LINE" value="true" />
<option name="KEEP_SIMPLE_CLASSES_IN_ONE_LINE" value="true" /> <option name="KEEP_SIMPLE_CLASSES_IN_ONE_LINE" value="true" />
<option name="ARRAY_INITIALIZER_WRAP" value="5" />
<option name="ARRAY_INITIALIZER_LBRACE_ON_NEXT_LINE" value="true" />
<option name="ARRAY_INITIALIZER_RBRACE_ON_NEXT_LINE" value="true" />
<option name="ENUM_CONSTANTS_WRAP" value="2" /> <option name="ENUM_CONSTANTS_WRAP" value="2" />
</codeStyleSettings> </codeStyleSettings>
<codeStyleSettings language="XML"> <codeStyleSettings language="XML">

View file

@ -21,6 +21,12 @@ dependencies {
implementation("org.slf4j", "slf4j-api", "2.0.3") implementation("org.slf4j", "slf4j-api", "2.0.3")
implementation("org.xerial", "sqlite-jdbc", "3.39.3.0") implementation("org.xerial", "sqlite-jdbc", "3.39.3.0")
implementation("com.zaxxer", "HikariCP", "5.0.1") implementation("com.zaxxer", "HikariCP", "5.0.1")
testImplementation("org.junit.jupiter", "junit-jupiter", "5.9.0")
}
tasks.named<Test>("test") {
useJUnitPlatform()
} }
configurations { configurations {

View file

@ -56,16 +56,19 @@ import org.whispersystems.signalservice.api.messages.SignalServiceEnvelope;
import org.whispersystems.signalservice.api.messages.SignalServiceGroup; import org.whispersystems.signalservice.api.messages.SignalServiceGroup;
import org.whispersystems.signalservice.api.messages.SignalServiceGroupContext; import org.whispersystems.signalservice.api.messages.SignalServiceGroupContext;
import org.whispersystems.signalservice.api.messages.SignalServiceGroupV2; import org.whispersystems.signalservice.api.messages.SignalServiceGroupV2;
import org.whispersystems.signalservice.api.messages.SignalServicePniSignatureMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage; import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage;
import org.whispersystems.signalservice.api.messages.SignalServiceStoryMessage; import org.whispersystems.signalservice.api.messages.SignalServiceStoryMessage;
import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage; import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage;
import org.whispersystems.signalservice.api.messages.multidevice.StickerPackOperationMessage; import org.whispersystems.signalservice.api.messages.multidevice.StickerPackOperationMessage;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.PNI; import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public final class IncomingMessageHandler { public final class IncomingMessageHandler {
@ -194,7 +197,18 @@ public final class IncomingMessageHandler {
if (content != null) { if (content != null) {
// Store uuid if we don't have it already // Store uuid if we don't have it already
// address/uuid is validated by unidentified sender certificate // address/uuid is validated by unidentified sender certificate
account.getRecipientTrustedResolver().resolveRecipientTrusted(content.getSender());
boolean handledPniSignature = false;
if (content.getPniSignatureMessage().isPresent()) {
final var message = content.getPniSignatureMessage().get();
final var senderAddress = getSenderAddress(envelope, content);
if (senderAddress != null) {
handledPniSignature = handlePniSignatureMessage(message, senderAddress);
}
}
if (!handledPniSignature) {
account.getRecipientTrustedResolver().resolveRecipientTrusted(content.getSender());
}
} }
if (envelope.isReceipt()) { if (envelope.isReceipt()) {
final var senderDeviceAddress = getSender(envelope, content); final var senderDeviceAddress = getSender(envelope, content);
@ -215,8 +229,9 @@ public final class IncomingMessageHandler {
logger.info("Ignoring a message from blocked user/group: {}", envelope.getTimestamp()); logger.info("Ignoring a message from blocked user/group: {}", envelope.getTimestamp());
return List.of(); return List.of();
} else if (notAllowedToSendToGroup) { } else if (notAllowedToSendToGroup) {
final var senderAddress = getSenderAddress(envelope, content);
logger.info("Ignoring a group message from an unauthorized sender (no member or admin): {} {}", logger.info("Ignoring a group message from an unauthorized sender (no member or admin): {} {}",
(envelope.hasSourceUuid() ? envelope.getSourceAddress() : content.getSender()).getIdentifier(), senderAddress == null ? null : senderAddress.getIdentifier(),
envelope.getTimestamp()); envelope.getTimestamp());
return List.of(); return List.of();
} else { } else {
@ -323,6 +338,32 @@ public final class IncomingMessageHandler {
return actions; return actions;
} }
private boolean handlePniSignatureMessage(
final SignalServicePniSignatureMessage message, final SignalServiceAddress senderAddress
) {
final var aci = ACI.from(senderAddress.getServiceId());
final var aciIdentity = account.getIdentityKeyStore().getIdentityInfo(aci);
final var pni = message.getPni();
final var pniIdentity = account.getIdentityKeyStore().getIdentityInfo(pni);
if (aciIdentity == null || pniIdentity == null || aci.equals(pni)) {
return false;
}
final var verified = pniIdentity.getIdentityKey()
.verifyAlternateIdentity(aciIdentity.getIdentityKey(), message.getSignature());
if (!verified) {
logger.debug("Invalid PNI signature of ACI {} with PNI {}", aci, pni);
return false;
}
logger.debug("Verified association of ACI {} with PNI {}", aci, pni);
account.getRecipientTrustedResolver()
.resolveRecipientTrusted(Optional.of(aci), Optional.of(pni), senderAddress.getNumber());
return true;
}
private void handleDecryptionErrorMessage( private void handleDecryptionErrorMessage(
final List<HandleAction> actions, final List<HandleAction> actions,
final RecipientId sender, final RecipientId sender,
@ -585,12 +626,8 @@ public final class IncomingMessageHandler {
} }
private boolean isMessageBlocked(SignalServiceEnvelope envelope, SignalServiceContent content) { private boolean isMessageBlocked(SignalServiceEnvelope envelope, SignalServiceContent content) {
SignalServiceAddress source; SignalServiceAddress source = getSenderAddress(envelope, content);
if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) { if (source == null) {
source = envelope.getSourceAddress();
} else if (content != null) {
source = content.getSender();
} else {
return false; return false;
} }
final var recipientId = context.getRecipientHelper().resolveRecipient(source); final var recipientId = context.getRecipientHelper().resolveRecipient(source);
@ -608,12 +645,8 @@ public final class IncomingMessageHandler {
} }
private boolean isNotAllowedToSendToGroup(SignalServiceEnvelope envelope, SignalServiceContent content) { private boolean isNotAllowedToSendToGroup(SignalServiceEnvelope envelope, SignalServiceContent content) {
SignalServiceAddress source; SignalServiceAddress source = getSenderAddress(envelope, content);
if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) { if (source == null) {
source = envelope.getSourceAddress();
} else if (content != null) {
source = content.getSender();
} else {
return false; return false;
} }
@ -853,6 +886,16 @@ public final class IncomingMessageHandler {
this.account.getProfileStore().storeProfileKey(source, profileKey); this.account.getProfileStore().storeProfileKey(source, profileKey);
} }
private SignalServiceAddress getSenderAddress(SignalServiceEnvelope envelope, SignalServiceContent content) {
if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) {
return envelope.getSourceAddress();
} else if (content != null) {
return content.getSender();
} else {
return null;
}
}
private DeviceAddress getSender(SignalServiceEnvelope envelope, SignalServiceContent content) { private DeviceAddress getSender(SignalServiceEnvelope envelope, SignalServiceContent content) {
if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) { if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) {
return new DeviceAddress(context.getRecipientHelper().resolveRecipient(envelope.getSourceAddress()), return new DeviceAddress(context.getRecipientHelper().resolveRecipient(envelope.getSourceAddress()),

View file

@ -22,7 +22,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database { public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 10; private static final long DATABASE_VERSION = 11;
private AccountDatabase(final HikariDataSource dataSource) { private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource); super(logger, DATABASE_VERSION, dataSource);
@ -288,5 +288,13 @@ public class AccountDatabase extends Database {
"""); """);
} }
} }
if (oldVersion < 11) {
logger.debug("Updating database: Adding pni field");
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
ALTER TABLE recipient ADD COLUMN pni BLOB;
""");
}
}
} }
} }

View file

@ -102,7 +102,7 @@ public class SignalAccount implements Closeable {
private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class); private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class);
private static final int MINIMUM_STORAGE_VERSION = 1; private static final int MINIMUM_STORAGE_VERSION = 1;
private static final int CURRENT_STORAGE_VERSION = 5; private static final int CURRENT_STORAGE_VERSION = 6;
private final Object LOCK = new Object(); private final Object LOCK = new Object();
@ -634,6 +634,9 @@ public class SignalAccount implements Closeable {
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
} }
if (previousStorageVersion < 6) {
getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress());
}
final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath); final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath);
if (legacyAciPreKeysPath.exists()) { if (legacyAciPreKeysPath.exists()) {
LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore()); LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore());

View file

@ -0,0 +1,140 @@
package org.asamk.signal.manager.storage.recipients;
import org.asamk.signal.manager.api.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class MergeRecipientHelper {
private final static Logger logger = LoggerFactory.getLogger(MergeRecipientHelper.class);
static Pair<RecipientId, List<RecipientId>> resolveRecipientTrustedLocked(
Store store, RecipientAddress address
) throws SQLException {
// address has serviceId and number, optionally also pni
final var recipients = store.findAllByAddress(address);
if (recipients.isEmpty()) {
logger.debug("Got new recipient, serviceId, PNI and number are unknown");
return new Pair<>(store.addNewRecipient(address), List.of());
}
if (recipients.size() == 1) {
final var recipient = recipients.stream().findFirst().get();
if (recipient.address().hasIdentifiersOf(address)) {
return new Pair<>(recipient.id(), List.of());
}
if (recipient.address().serviceId().isEmpty() || (
recipient.address().serviceId().equals(address.serviceId())
) || (
recipient.address().pni().isPresent() && recipient.address().pni().equals(address.serviceId())
) || (
recipient.address().serviceId().equals(address.pni())
) || (
address.pni().isPresent() && address.pni().equals(recipient.address().pni())
)) {
logger.debug("Got existing recipient {}, updating with high trust address", recipient.id());
store.updateRecipientAddress(recipient.id(), recipient.address().withIdentifiersFrom(address));
return new Pair<>(recipient.id(), List.of());
}
logger.debug(
"Got recipient {} existing with number/pni, but different serviceId, so stripping its number and adding new recipient",
recipient.id());
store.updateRecipientAddress(recipient.id(), recipient.address().removeIdentifiersFrom(address));
return new Pair<>(store.addNewRecipient(address), List.of());
}
var resultingRecipient = recipients.stream()
.filter(r -> r.address().serviceId().equals(address.serviceId()) || r.address()
.pni()
.equals(address.serviceId()))
.findFirst();
if (resultingRecipient.isEmpty() && address.pni().isPresent()) {
resultingRecipient = recipients.stream().filter(r -> r.address().serviceId().equals(address.pni()) || (
address.serviceId().equals(address.pni()) && r.address().pni().equals(address.pni())
)).findFirst();
}
final Set<RecipientWithAddress> remainingRecipients;
if (resultingRecipient.isEmpty()) {
remainingRecipients = recipients;
} else {
remainingRecipients = new HashSet<>(recipients);
remainingRecipients.remove(resultingRecipient.get());
}
final var recipientsToBeMerged = new HashSet<RecipientWithAddress>();
final var recipientsToBeStripped = new HashSet<RecipientWithAddress>();
for (final var recipient : remainingRecipients) {
if (!recipient.address().hasAdditionalIdentifiersThan(address)) {
recipientsToBeMerged.add(recipient);
continue;
}
if (recipient.address().hasOnlyPniAndNumber()) {
// PNI and phone number are linked by the server
recipientsToBeMerged.add(recipient);
continue;
}
recipientsToBeStripped.add(recipient);
}
logger.debug("Got separate recipients for high trust identifiers {}, need to merge ({}) and strip ({})",
address,
recipientsToBeMerged.stream().map(r -> r.id().toString()).collect(Collectors.joining(", ")),
recipientsToBeStripped.stream().map(r -> r.id().toString()).collect(Collectors.joining(", ")));
RecipientAddress finalAddress = resultingRecipient.map(RecipientWithAddress::address).orElse(null);
for (final var recipient : recipientsToBeMerged) {
if (finalAddress == null) {
finalAddress = recipient.address();
} else {
finalAddress = finalAddress.withIdentifiersFrom(recipient.address());
}
store.removeRecipientAddress(recipient.id());
}
if (finalAddress == null) {
finalAddress = address;
} else {
finalAddress = finalAddress.withIdentifiersFrom(address);
}
for (final var recipient : recipientsToBeStripped) {
store.updateRecipientAddress(recipient.id(), recipient.address().removeIdentifiersFrom(address));
}
// Create fixed RecipientIds that won't update its id after merged
final var toBeMergedRecipientIds = recipientsToBeMerged.stream()
.map(r -> new RecipientId(r.id().id(), null))
.toList();
if (resultingRecipient.isPresent()) {
store.updateRecipientAddress(resultingRecipient.get().id(), finalAddress);
return new Pair<>(resultingRecipient.get().id(), toBeMergedRecipientIds);
}
return new Pair<>(store.addNewRecipient(finalAddress), toBeMergedRecipientIds);
}
public interface Store {
Set<RecipientWithAddress> findAllByAddress(final RecipientAddress address) throws SQLException;
RecipientId addNewRecipient(final RecipientAddress address) throws SQLException;
void updateRecipientAddress(RecipientId recipientId, final RecipientAddress address) throws SQLException;
void removeRecipientAddress(RecipientId recipientId) throws SQLException;
}
}

View file

@ -24,6 +24,14 @@ public record RecipientAddress(Optional<ServiceId> serviceId, Optional<PNI> pni,
if (serviceId.isEmpty() && pni.isPresent()) { if (serviceId.isEmpty() && pni.isPresent()) {
serviceId = Optional.of(pni.get()); serviceId = Optional.of(pni.get());
} }
if (serviceId.isPresent() && serviceId.get() instanceof PNI sPNI) {
if (pni.isPresent() && !sPNI.equals(pni.get())) {
throw new AssertionError("Must not have two different PNIs!");
}
if (pni.isEmpty()) {
pni = Optional.of(sPNI);
}
}
if (serviceId.isEmpty() && number.isEmpty()) { if (serviceId.isEmpty() && number.isEmpty()) {
throw new AssertionError("Must have either a ServiceId or E164 number!"); throw new AssertionError("Must have either a ServiceId or E164 number!");
} }
@ -49,6 +57,22 @@ public record RecipientAddress(Optional<ServiceId> serviceId, Optional<PNI> pni,
this(Optional.of(serviceId), Optional.empty()); this(Optional.of(serviceId), Optional.empty());
} }
public RecipientAddress withIdentifiersFrom(RecipientAddress address) {
return new RecipientAddress((
this.serviceId.isEmpty() || this.isServiceIdPNI() || this.serviceId.equals(address.pni)
) && !address.isServiceIdPNI() ? address.serviceId : this.serviceId,
address.pni.or(this::pni),
address.number.or(this::number));
}
public RecipientAddress removeIdentifiersFrom(RecipientAddress address) {
return new RecipientAddress(address.serviceId.equals(this.serviceId) || address.pni.equals(this.serviceId)
? Optional.empty()
: this.serviceId,
address.pni.equals(this.pni) || address.serviceId.equals(this.pni) ? Optional.empty() : this.pni,
address.number.equals(this.number) ? Optional.empty() : this.number);
}
public ServiceId getServiceId() { public ServiceId getServiceId() {
return serviceId.orElse(ServiceId.UNKNOWN); return serviceId.orElse(ServiceId.UNKNOWN);
} }
@ -89,6 +113,42 @@ public record RecipientAddress(Optional<ServiceId> serviceId, Optional<PNI> pni,
); );
} }
public boolean hasSingleIdentifier() {
return serviceId().isEmpty() || number.isEmpty();
}
public boolean hasIdentifiersOf(RecipientAddress address) {
return (address.serviceId.isEmpty() || address.serviceId.equals(serviceId) || address.serviceId.equals(pni))
&& (address.pni.isEmpty() || address.pni.equals(pni))
&& (address.number.isEmpty() || address.number.equals(number));
}
public boolean hasAdditionalIdentifiersThan(RecipientAddress address) {
return (
serviceId.isPresent() && (
address.serviceId.isEmpty() || (
!address.serviceId.equals(serviceId) && !address.pni.equals(serviceId)
)
)
) || (
pni.isPresent() && !address.serviceId.equals(pni) && (
address.pni.isEmpty() || !address.pni.equals(pni)
)
) || (
number.isPresent() && (
address.number.isEmpty() || !address.number.equals(number)
)
);
}
public boolean hasOnlyPniAndNumber() {
return pni.isPresent() && serviceId.equals(pni) && number.isPresent();
}
public boolean isServiceIdPNI() {
return serviceId.isPresent() && (pni.isPresent() && serviceId.equals(pni));
}
public SignalServiceAddress toSignalServiceAddress() { public SignalServiceAddress toSignalServiceAddress() {
return new SignalServiceAddress(getServiceId(), number); return new SignalServiceAddress(getServiceId(), number);
} }

View file

@ -53,6 +53,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
_id INTEGER PRIMARY KEY AUTOINCREMENT, _id INTEGER PRIMARY KEY AUTOINCREMENT,
number TEXT UNIQUE, number TEXT UNIQUE,
uuid BLOB UNIQUE, uuid BLOB UNIQUE,
pni BLOB UNIQUE,
profile_key BLOB, profile_key BLOB,
profile_key_credential BLOB, profile_key_credential BLOB,
@ -92,7 +93,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
public RecipientAddress resolveRecipientAddress(RecipientId recipientId) { public RecipientAddress resolveRecipientAddress(RecipientId recipientId) {
final var sql = ( final var sql = (
""" """
SELECT r.number, r.uuid SELECT r.number, r.uuid, r.pni
FROM %s r FROM %s r
WHERE r._id = ? WHERE r._id = ?
""" """
@ -246,7 +247,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final Optional<ACI> aci, final Optional<PNI> pni, final Optional<String> number final Optional<ACI> aci, final Optional<PNI> pni, final Optional<String> number
) { ) {
final var serviceId = aci.map(a -> (ServiceId) a).or(() -> pni); final var serviceId = aci.map(a -> (ServiceId) a).or(() -> pni);
return resolveRecipientTrusted(new RecipientAddress(serviceId, number), false); return resolveRecipientTrusted(new RecipientAddress(serviceId, pni, number), false);
} }
@Override @Override
@ -308,7 +309,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final var sql = ( final var sql = (
""" """
SELECT r._id, SELECT r._id,
r.number, r.uuid, r.number, r.uuid, r.pni,
r.profile_key, r.profile_key_credential, r.profile_key, r.profile_key_credential,
r.given_name, r.family_name, r.expiration_time, r.profile_sharing, r.color, r.blocked, r.archived, r.given_name, r.family_name, r.expiration_time, r.profile_sharing, r.color, r.blocked, r.archived,
r.profile_last_update_timestamp, r.profile_given_name, r.profile_family_name, r.profile_about, r.profile_about_emoji, r.profile_avatar_url_path, r.profile_mobile_coin_address, r.profile_unidentified_access_mode, r.profile_capabilities r.profile_last_update_timestamp, r.profile_given_name, r.profile_family_name, r.profile_about, r.profile_about_emoji, r.profile_avatar_url_path, r.profile_mobile_coin_address, r.profile_unidentified_access_mode, r.profile_capabilities
@ -601,21 +602,33 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
} }
private RecipientId resolveRecipientTrusted(RecipientAddress address, boolean isSelf) { private RecipientId resolveRecipientTrusted(RecipientAddress address, boolean isSelf) {
final Pair<RecipientId, Optional<RecipientId>> pair; final Pair<RecipientId, List<RecipientId>> pair;
synchronized (recipientsLock) { synchronized (recipientsLock) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
connection.setAutoCommit(false); connection.setAutoCommit(false);
pair = resolveRecipientTrustedLocked(connection, address, isSelf); if (address.hasSingleIdentifier() || (
!isSelf && selfAddressProvider.getSelfAddress().matches(address)
)) {
pair = new Pair<>(resolveRecipientLocked(connection, address), List.of());
} else {
pair = MergeRecipientHelper.resolveRecipientTrustedLocked(new HelperStore(connection), address);
for (final var toBeMergedRecipientId : pair.second()) {
mergeRecipientsLocked(connection, pair.first(), toBeMergedRecipientId);
}
}
connection.commit(); connection.commit();
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update recipient store", e); throw new RuntimeException("Failed update recipient store", e);
} }
} }
if (pair.second().isPresent()) { if (pair.second().size() > 0) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
recipientMergeHandler.mergeRecipients(connection, pair.first(), pair.second().get()); for (final var toBeMergedRecipientId : pair.second()) {
deleteRecipient(connection, pair.second().get()); recipientMergeHandler.mergeRecipients(connection, pair.first(), toBeMergedRecipientId);
deleteRecipient(connection, toBeMergedRecipientId);
}
} catch (SQLException e) { } catch (SQLException e) {
throw new RuntimeException("Failed update recipient store", e); throw new RuntimeException("Failed update recipient store", e);
} }
@ -623,82 +636,6 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
return pair.first(); return pair.first();
} }
private Pair<RecipientId, Optional<RecipientId>> resolveRecipientTrustedLocked(
Connection connection, RecipientAddress address, boolean isSelf
) throws SQLException {
if (!isSelf) {
if (selfAddressProvider.getSelfAddress().matches(address)) {
return new Pair<>(resolveRecipientLocked(connection, address), Optional.empty());
}
}
final var byNumber = address.number().isEmpty()
? Optional.<RecipientWithAddress>empty()
: findByNumber(connection, address.number().get());
final var byUuid = address.serviceId().isEmpty()
? Optional.<RecipientWithAddress>empty()
: findByServiceId(connection, address.serviceId().get());
if (byNumber.isEmpty() && byUuid.isEmpty()) {
logger.debug("Got new recipient, both uuid and number are unknown");
return new Pair<>(addNewRecipient(connection, address), Optional.empty());
}
if (address.serviceId().isEmpty() || address.number().isEmpty() || byNumber.equals(byUuid)) {
return new Pair<>(byUuid.or(() -> byNumber).map(RecipientWithAddress::id).get(), Optional.empty());
}
if (byNumber.isEmpty()) {
logger.debug("Got recipient {} existing with uuid, updating with high trust number", byUuid.get().id());
updateRecipientAddress(connection, byUuid.get().id(), address);
return new Pair<>(byUuid.get().id(), Optional.empty());
}
final var byNumberRecipient = byNumber.get();
if (byUuid.isEmpty()) {
if (byNumberRecipient.address().serviceId().isPresent()) {
logger.debug(
"Got recipient {} existing with number, but different uuid, so stripping its number and adding new recipient",
byNumberRecipient.id());
updateRecipientAddress(connection,
byNumberRecipient.id(),
new RecipientAddress(byNumberRecipient.address().serviceId().get()));
return new Pair<>(addNewRecipient(connection, address), Optional.empty());
}
logger.debug("Got recipient {} existing with number and no uuid, updating with high trust uuid",
byNumberRecipient.id());
updateRecipientAddress(connection, byNumberRecipient.id(), address);
return new Pair<>(byNumberRecipient.id(), Optional.empty());
}
final var byUuidRecipient = byUuid.get();
if (byNumberRecipient.address().serviceId().isPresent()) {
logger.debug(
"Got separate recipients for high trust number {} and uuid {}, recipient for number has different uuid, so stripping its number",
byNumberRecipient.id(),
byUuidRecipient.id());
updateRecipientAddress(connection,
byNumberRecipient.id(),
new RecipientAddress(byNumberRecipient.address().serviceId().get()));
updateRecipientAddress(connection, byUuidRecipient.id(), address);
return new Pair<>(byUuidRecipient.id(), Optional.empty());
}
logger.debug("Got separate recipients for high trust number {} and uuid {}, need to merge them",
byNumberRecipient.id(),
byUuidRecipient.id());
// Create a fixed RecipientId that won't update its id after merge
final var toBeMergedRecipientId = new RecipientId(byNumberRecipient.id().id(), null);
mergeRecipientsLocked(connection, byUuidRecipient.id(), toBeMergedRecipientId);
removeRecipientAddress(connection, toBeMergedRecipientId);
updateRecipientAddress(connection, byUuidRecipient.id(), address);
return new Pair<>(byUuidRecipient.id(), Optional.of(toBeMergedRecipientId));
}
private RecipientId resolveRecipientLocked( private RecipientId resolveRecipientLocked(
Connection connection, RecipientAddress address Connection connection, RecipientAddress address
) throws SQLException { ) throws SQLException {
@ -762,13 +699,14 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
) throws SQLException { ) throws SQLException {
final var sql = ( final var sql = (
""" """
INSERT INTO %s (number, uuid) INSERT INTO %s (number, uuid, pni)
VALUES (?, ?) VALUES (?, ?, ?)
""" """
).formatted(TABLE_RECIPIENT); ).formatted(TABLE_RECIPIENT);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setString(1, address.number().orElse(null)); statement.setString(1, address.number().orElse(null));
statement.setBytes(2, address.serviceId().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null)); statement.setBytes(2, address.serviceId().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.setBytes(3, address.pni().map(PNI::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.executeUpdate(); statement.executeUpdate();
final var generatedKeys = statement.getGeneratedKeys(); final var generatedKeys = statement.getGeneratedKeys();
if (generatedKeys.next()) { if (generatedKeys.next()) {
@ -785,7 +723,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final var sql = ( final var sql = (
""" """
UPDATE %s UPDATE %s
SET number = NULL, uuid = NULL SET number = NULL, uuid = NULL, pni = NULL
WHERE _id = ? WHERE _id = ?
""" """
).formatted(TABLE_RECIPIENT); ).formatted(TABLE_RECIPIENT);
@ -801,14 +739,15 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final var sql = ( final var sql = (
""" """
UPDATE %s UPDATE %s
SET number = ?, uuid = ? SET number = ?, uuid = ?, pni = ?
WHERE _id = ? WHERE _id = ?
""" """
).formatted(TABLE_RECIPIENT); ).formatted(TABLE_RECIPIENT);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setString(1, address.number().orElse(null)); statement.setString(1, address.number().orElse(null));
statement.setBytes(2, address.serviceId().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null)); statement.setBytes(2, address.serviceId().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.setLong(3, recipientId.id()); statement.setBytes(3, address.pni().map(PNI::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.setLong(4, recipientId.id());
statement.executeUpdate(); statement.executeUpdate();
} }
} }
@ -861,9 +800,10 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final Connection connection, final String number final Connection connection, final String number
) throws SQLException { ) throws SQLException {
final var sql = """ final var sql = """
SELECT r._id, r.number, r.uuid SELECT r._id, r.number, r.uuid, r.pni
FROM %s r FROM %s r
WHERE r.number = ? WHERE r.number = ?
LIMIT 1
""".formatted(TABLE_RECIPIENT); """.formatted(TABLE_RECIPIENT);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setString(1, number); statement.setString(1, number);
@ -875,9 +815,10 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
final Connection connection, final ServiceId serviceId final Connection connection, final ServiceId serviceId
) throws SQLException { ) throws SQLException {
final var sql = """ final var sql = """
SELECT r._id, r.number, r.uuid SELECT r._id, r.number, r.uuid, r.pni
FROM %s r FROM %s r
WHERE r.uuid = ? WHERE r.uuid = ? OR r.pni = ?
LIMIT 1
""".formatted(TABLE_RECIPIENT); """.formatted(TABLE_RECIPIENT);
try (final var statement = connection.prepareStatement(sql)) { try (final var statement = connection.prepareStatement(sql)) {
statement.setBytes(1, UuidUtil.toByteArray(serviceId.uuid())); statement.setBytes(1, UuidUtil.toByteArray(serviceId.uuid()));
@ -885,6 +826,25 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
} }
} }
private Set<RecipientWithAddress> findAllByAddress(
final Connection connection, final RecipientAddress address
) throws SQLException {
final var sql = """
SELECT r._id, r.number, r.uuid, r.pni
FROM %s r
WHERE r.uuid = ?1 OR r.pni = ?1 OR
r.uuid = ?2 OR r.pni = ?2 OR
r.number = ?3
""".formatted(TABLE_RECIPIENT);
try (final var statement = connection.prepareStatement(sql)) {
statement.setBytes(1, address.serviceId().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.setBytes(2, address.pni().map(ServiceId::uuid).map(UuidUtil::toByteArray).orElse(null));
statement.setString(3, address.number().orElse(null));
return Utils.executeQueryForStream(statement, this::getRecipientWithAddressFromResultSet)
.collect(Collectors.toSet());
}
}
private Contact getContact(final Connection connection, final RecipientId recipientId) throws SQLException { private Contact getContact(final Connection connection, final RecipientId recipientId) throws SQLException {
final var sql = ( final var sql = (
""" """
@ -946,8 +906,9 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
private RecipientAddress getRecipientAddressFromResultSet(ResultSet resultSet) throws SQLException { private RecipientAddress getRecipientAddressFromResultSet(ResultSet resultSet) throws SQLException {
final var serviceId = Optional.ofNullable(resultSet.getBytes("uuid")).map(ServiceId::parseOrNull); final var serviceId = Optional.ofNullable(resultSet.getBytes("uuid")).map(ServiceId::parseOrNull);
final var pni = Optional.ofNullable(resultSet.getBytes("pni")).map(PNI::parseOrNull);
final var number = Optional.ofNullable(resultSet.getString("number")); final var number = Optional.ofNullable(resultSet.getString("number"));
return new RecipientAddress(serviceId, Optional.empty(), number); return new RecipientAddress(serviceId, pni, number);
} }
private RecipientId getRecipientIdFromResultSet(ResultSet resultSet) throws SQLException { private RecipientId getRecipientIdFromResultSet(ResultSet resultSet) throws SQLException {
@ -1032,5 +993,34 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
) throws SQLException; ) throws SQLException;
} }
private record RecipientWithAddress(RecipientId id, RecipientAddress address) {} private class HelperStore implements MergeRecipientHelper.Store {
private final Connection connection;
public HelperStore(final Connection connection) {
this.connection = connection;
}
@Override
public Set<RecipientWithAddress> findAllByAddress(final RecipientAddress address) throws SQLException {
return RecipientStore.this.findAllByAddress(connection, address);
}
@Override
public RecipientId addNewRecipient(final RecipientAddress address) throws SQLException {
return RecipientStore.this.addNewRecipient(connection, address);
}
@Override
public void updateRecipientAddress(
final RecipientId recipientId, final RecipientAddress address
) throws SQLException {
RecipientStore.this.updateRecipientAddress(connection, recipientId, address);
}
@Override
public void removeRecipientAddress(final RecipientId recipientId) throws SQLException {
RecipientStore.this.removeRecipientAddress(connection, recipientId);
}
}
} }

View file

@ -0,0 +1,3 @@
package org.asamk.signal.manager.storage.recipients;
record RecipientWithAddress(RecipientId id, RecipientAddress address) {}

View file

@ -0,0 +1,330 @@
package org.asamk.signal.manager.storage.recipients;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.junit.jupiter.api.Assertions.assertEquals;
class MergeRecipientHelperTest {
static final ServiceId SERVICE_ID_A = ServiceId.from(UUID.randomUUID());
static final ServiceId SERVICE_ID_B = ServiceId.from(UUID.randomUUID());
static final ServiceId SERVICE_ID_C = ServiceId.from(UUID.randomUUID());
static final PNI PNI_A = PNI.from(UUID.randomUUID());
static final PNI PNI_B = PNI.from(UUID.randomUUID());
static final PNI PNI_C = PNI.from(UUID.randomUUID());
static final String NUMBER_A = "+AAA";
static final String NUMBER_B = "+BBB";
static final String NUMBER_C = "+CCC";
static final PartialAddresses ADDR_A = new PartialAddresses(SERVICE_ID_A, PNI_A, NUMBER_A);
static final PartialAddresses ADDR_B = new PartialAddresses(SERVICE_ID_B, PNI_B, NUMBER_B);
static T[] testInstancesNone = new T[]{
// 1
new T(Set.of(), ADDR_A.FULL, Set.of(rec(1000000, ADDR_A.FULL))),
new T(Set.of(), ADDR_A.ACI_NUM, Set.of(rec(1000000, ADDR_A.ACI_NUM))),
new T(Set.of(), ADDR_A.ACI_PNI, Set.of(rec(1000000, ADDR_A.ACI_PNI))),
new T(Set.of(), ADDR_A.PNI_S_NUM, Set.of(rec(1000000, ADDR_A.PNI_S_NUM))),
new T(Set.of(), ADDR_A.PNI_NUM, Set.of(rec(1000000, ADDR_A.PNI_NUM))),
};
static T[] testInstancesSingle = new T[]{
// 1
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI_NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S_NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI_PNI)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
// 10
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.PNI), rec(1000000, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S)),
ADDR_A.ACI_NUM,
Set.of(rec(1, ADDR_A.PNI_S), rec(1000000, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.NUM)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_NUM)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_NUM)),
ADDR_A.ACI_NUM,
Set.of(rec(1, ADDR_A.PNI), rec(1000000, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S_NUM)),
ADDR_A.ACI_NUM,
Set.of(rec(1, ADDR_A.PNI_S), rec(1000000, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_PNI)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.FULL))),
// 19
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.ACI), rec(1000000, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.NUM)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_NUM)),
ADDR_A.PNI_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(1000000, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_NUM)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S_NUM)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_PNI)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.FULL))),
// 28
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(1000000, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.NUM)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_NUM)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(1000000, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_NUM)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S_NUM)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI_PNI)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.FULL))),
// 37
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.PNI)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.PNI_S)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.NUM)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.NUM), rec(1000000, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.ACI_NUM)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_NUM)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S_NUM)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI_PNI)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.FULL)), ADDR_B.FULL, Set.of(rec(1, ADDR_A.FULL), rec(1000000, ADDR_B.FULL))),
};
static T[] testInstancesTwo = new T[]{
// 1
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.ACI_NUM)), ADDR_A.FULL, Set.of(rec(2, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.NUM)), ADDR_A.FULL, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.ACI_NUM)), ADDR_A.FULL, Set.of(rec(2, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.PNI_S)), ADDR_A.FULL, Set.of(rec(2, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.ACI_PNI)), ADDR_A.FULL, Set.of(rec(2, ADDR_A.FULL))),
// 12
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.NUM)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.ACI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM)), ADDR_A.ACI_NUM, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM)),
ADDR_A.ACI_NUM,
Set.of(rec(1, ADDR_A.ACI_NUM), rec(2, ADDR_A.PNI_S))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.ACI_PNI)), ADDR_A.ACI_NUM, Set.of(rec(2, ADDR_A.FULL))),
// 16
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM)),
ADDR_A.PNI_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM)),
ADDR_A.PNI_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.NUM)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.ACI_NUM)),
ADDR_A.PNI_NUM,
Set.of(rec(1, ADDR_A.PNI_NUM), rec(2, ADDR_A.ACI))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.NUM)), ADDR_A.PNI_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.ACI_NUM)),
ADDR_A.PNI_NUM,
Set.of(rec(1, ADDR_A.PNI_NUM), rec(2, ADDR_A.ACI))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.PNI_S)), ADDR_A.PNI_NUM, Set.of(rec(2, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.ACI_PNI)), ADDR_A.PNI_NUM, Set.of(rec(2, ADDR_A.FULL))),
// 24
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.NUM)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.ACI_NUM)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.PNI_NUM), rec(2, ADDR_A.ACI))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.NUM)), ADDR_A.PNI_S_NUM, Set.of(rec(1, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.ACI_NUM)),
ADDR_A.PNI_S_NUM,
Set.of(rec(1, ADDR_A.PNI_S_NUM), rec(2, ADDR_A.ACI))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.PNI_S)), ADDR_A.PNI_S_NUM, Set.of(rec(2, ADDR_A.PNI_S_NUM))),
new T(Set.of(rec(1, ADDR_A.NUM), rec(2, ADDR_A.ACI_PNI)), ADDR_A.PNI_S_NUM, Set.of(rec(2, ADDR_A.FULL))),
// 32
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.ACI_PNI))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_NUM)), ADDR_A.ACI_PNI, Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI_S_NUM)),
ADDR_A.ACI_PNI,
Set.of(rec(1, ADDR_A.ACI_PNI), rec(2, ADDR_A.NUM))),
new T(Set.of(rec(1, ADDR_A.PNI), rec(2, ADDR_A.ACI_NUM)), ADDR_A.ACI_PNI, Set.of(rec(2, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.PNI_S), rec(2, ADDR_A.ACI_NUM)), ADDR_A.ACI_PNI, Set.of(rec(2, ADDR_A.FULL))),
};
static T[] testInstancesThree = new T[]{
// 1
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI), rec(3, ADDR_A.NUM)),
ADDR_A.FULL,
Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI.withIdentifiersFrom(ADDR_B.PNI)), rec(2, ADDR_A.PNI), rec(3, ADDR_A.NUM)),
ADDR_A.FULL,
Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI.withIdentifiersFrom(ADDR_B.NUM)), rec(2, ADDR_A.PNI), rec(3, ADDR_A.NUM)),
ADDR_A.FULL,
Set.of(rec(1, ADDR_A.FULL))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI), rec(3, ADDR_A.NUM.withIdentifiersFrom(ADDR_B.ACI))),
ADDR_A.FULL,
Set.of(rec(1, ADDR_A.FULL), rec(3, ADDR_B.ACI))),
new T(Set.of(rec(1, ADDR_A.ACI), rec(2, ADDR_A.PNI.withIdentifiersFrom(ADDR_B.ACI)), rec(3, ADDR_A.NUM)),
ADDR_A.FULL,
Set.of(rec(1, ADDR_A.FULL), rec(2, ADDR_B.ACI))),
};
@ParameterizedTest
@MethodSource
void resolveRecipientTrustedLocked_NoneExisting(T test) throws Exception {
final var testStore = new TestStore(test.input);
MergeRecipientHelper.resolveRecipientTrustedLocked(testStore, test.request);
assertEquals(test.output, testStore.getRecipients());
}
private static Stream<Arguments> resolveRecipientTrustedLocked_NoneExisting() {
return Arrays.stream(testInstancesNone).map(Arguments::of);
}
@ParameterizedTest
@MethodSource
void resolveRecipientTrustedLocked_SingleExisting(T test) throws Exception {
final var testStore = new TestStore(test.input);
MergeRecipientHelper.resolveRecipientTrustedLocked(testStore, test.request);
assertEquals(test.output, testStore.getRecipients());
}
private static Stream<Arguments> resolveRecipientTrustedLocked_SingleExisting() {
return Arrays.stream(testInstancesSingle).map(Arguments::of);
}
@ParameterizedTest
@MethodSource
void resolveRecipientTrustedLocked_TwoExisting(T test) throws Exception {
final var testStore = new TestStore(test.input);
MergeRecipientHelper.resolveRecipientTrustedLocked(testStore, test.request);
assertEquals(test.output, testStore.getRecipients());
}
private static Stream<Arguments> resolveRecipientTrustedLocked_TwoExisting() {
return Arrays.stream(testInstancesTwo).map(Arguments::of);
}
@ParameterizedTest
@MethodSource
void resolveRecipientTrustedLocked_ThreeExisting(T test) throws Exception {
final var testStore = new TestStore(test.input);
MergeRecipientHelper.resolveRecipientTrustedLocked(testStore, test.request);
assertEquals(test.output, testStore.getRecipients());
}
private static Stream<Arguments> resolveRecipientTrustedLocked_ThreeExisting() {
return Arrays.stream(testInstancesThree).map(Arguments::of);
}
private static RecipientWithAddress rec(long recipientId, RecipientAddress address) {
return new RecipientWithAddress(new RecipientId(recipientId, null), address);
}
record T(
Set<RecipientWithAddress> input, RecipientAddress request, Set<RecipientWithAddress> output
) {
@Override
public String toString() {
return "T{#input=%s, request=%s_%s_%s, #output=%s}".formatted(input.size(),
request.serviceId().isPresent() ? "SVI" : "",
request.pni().isPresent() ? "PNI" : "",
request.number().isPresent() ? "NUM" : "",
output.size());
}
}
static class TestStore implements MergeRecipientHelper.Store {
final Set<RecipientWithAddress> recipients;
long nextRecipientId = 1000000;
TestStore(final Set<RecipientWithAddress> recipients) {
this.recipients = new HashSet<>(recipients);
}
public Set<RecipientWithAddress> getRecipients() {
return recipients;
}
@Override
public Set<RecipientWithAddress> findAllByAddress(final RecipientAddress address) {
return recipients.stream().filter(r -> r.address().matches(address)).collect(Collectors.toSet());
}
@Override
public RecipientId addNewRecipient(final RecipientAddress address) {
final var recipientId = new RecipientId(nextRecipientId++, null);
recipients.add(new RecipientWithAddress(recipientId, address));
return recipientId;
}
@Override
public void updateRecipientAddress(
final RecipientId recipientId, final RecipientAddress address
) {
recipients.removeIf(r -> r.id().equals(recipientId));
recipients.add(new RecipientWithAddress(recipientId, address));
}
@Override
public void removeRecipientAddress(final RecipientId recipientId) {
recipients.removeIf(r -> r.id().equals(recipientId));
}
}
private record PartialAddresses(
RecipientAddress FULL,
RecipientAddress ACI,
RecipientAddress PNI,
RecipientAddress PNI_S,
RecipientAddress NUM,
RecipientAddress ACI_NUM,
RecipientAddress PNI_NUM,
RecipientAddress PNI_S_NUM,
RecipientAddress ACI_PNI
) {
PartialAddresses(ServiceId serviceId, PNI pni, String number) {
this(new RecipientAddress(serviceId, pni, number),
new RecipientAddress(serviceId, null, null),
new RecipientAddress(null, pni, null),
new RecipientAddress(ServiceId.from(pni.uuid()), null, null),
new RecipientAddress(null, null, number),
new RecipientAddress(serviceId, null, number),
new RecipientAddress(null, pni, number),
new RecipientAddress(ServiceId.from(pni.uuid()), null, number),
new RecipientAddress(serviceId, pni, null));
}
}
}