Refactor check for registered users

This commit is contained in:
AsamK 2022-10-08 14:06:36 +02:00
parent 7eb7ee44f2
commit f2b334b57a
5 changed files with 77 additions and 33 deletions

View file

@ -217,13 +217,14 @@ class ManagerImpl implements Manager {
return numbers.stream().collect(Collectors.toMap(n -> n, n -> { return numbers.stream().collect(Collectors.toMap(n -> n, n -> {
final var number = canonicalizedNumbers.get(n); final var number = canonicalizedNumbers.get(n);
final var aci = registeredUsers.get(number); final var user = registeredUsers.get(number);
final var profile = aci == null final var serviceId = user == null ? null : user.getServiceId();
final var profile = serviceId == null
? null ? null
: context.getProfileHelper() : context.getProfileHelper()
.getRecipientProfile(account.getRecipientResolver().resolveRecipient(aci)); .getRecipientProfile(account.getRecipientResolver().resolveRecipient(serviceId));
return new UserStatus(number.isEmpty() ? null : number, return new UserStatus(number.isEmpty() ? null : number,
aci == null ? null : aci.uuid(), serviceId == null ? null : serviceId.uuid(),
profile != null profile != null
&& profile.getUnidentifiedAccessMode() == Profile.UnidentifiedAccessMode.UNRESTRICTED); && profile.getUnidentifiedAccessMode() == Profile.UnidentifiedAccessMode.UNRESTRICTED);
})); }));

View file

@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.InvalidKeyException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.services.CdsiV2Service; import org.whispersystems.signalservice.api.services.CdsiV2Service;
@ -50,9 +51,9 @@ public class RecipientHelper {
// Address in recipient store doesn't have a uuid, this shouldn't happen // Address in recipient store doesn't have a uuid, this shouldn't happen
// Try to retrieve the uuid from the server // Try to retrieve the uuid from the server
final var number = address.number().get(); final var number = address.number().get();
final ACI aci; final ServiceId serviceId;
try { try {
aci = getRegisteredUser(number); serviceId = getRegisteredUser(number);
} catch (UnregisteredRecipientException e) { } catch (UnregisteredRecipientException e) {
logger.warn("Failed to get uuid for e164 number: {}", number); logger.warn("Failed to get uuid for e164 number: {}", number);
// Return SignalServiceAddress with unknown UUID // Return SignalServiceAddress with unknown UUID
@ -63,7 +64,7 @@ public class RecipientHelper {
return address.toSignalServiceAddress(); return address.toSignalServiceAddress();
} }
return account.getRecipientAddressResolver() return account.getRecipientAddressResolver()
.resolveRecipientAddress(account.getRecipientResolver().resolveRecipient(aci)) .resolveRecipientAddress(account.getRecipientResolver().resolveRecipient(serviceId))
.toSignalServiceAddress(); .toSignalServiceAddress();
} }
@ -101,12 +102,13 @@ public class RecipientHelper {
return recipientId; return recipientId;
} }
final var number = address.getNumber().get(); final var number = address.getNumber().get();
final var uuid = getRegisteredUser(number); final var serviceId = getRegisteredUser(number);
return account.getRecipientTrustedResolver().resolveRecipientTrusted(new SignalServiceAddress(uuid, number)); return account.getRecipientTrustedResolver()
.resolveRecipientTrusted(new SignalServiceAddress(serviceId, number));
} }
public Map<String, ACI> getRegisteredUsers(final Set<String> numbers) throws IOException { public Map<String, RegisteredUser> getRegisteredUsers(final Set<String> numbers) throws IOException {
Map<String, ACI> registeredUsers; Map<String, RegisteredUser> registeredUsers;
try { try {
registeredUsers = getRegisteredUsersV2(numbers, true); registeredUsers = getRegisteredUsersV2(numbers, true);
} catch (IOException e) { } catch (IOException e) {
@ -115,30 +117,30 @@ public class RecipientHelper {
} }
// Store numbers as recipients, so we have the number/uuid association // Store numbers as recipients, so we have the number/uuid association
registeredUsers.forEach((number, aci) -> account.getRecipientTrustedResolver() registeredUsers.forEach((number, u) -> account.getRecipientTrustedResolver()
.resolveRecipientTrusted(new SignalServiceAddress(aci, number))); .resolveRecipientTrusted(u.aci, u.pni, Optional.of(number)));
return registeredUsers; return registeredUsers;
} }
private ACI getRegisteredUser(final String number) throws IOException, UnregisteredRecipientException { private ServiceId getRegisteredUser(final String number) throws IOException, UnregisteredRecipientException {
final Map<String, ACI> aciMap; final Map<String, RegisteredUser> aciMap;
try { try {
aciMap = getRegisteredUsers(Set.of(number)); aciMap = getRegisteredUsers(Set.of(number));
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number)); throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number));
} }
final var uuid = aciMap.get(number); final var user = aciMap.get(number);
if (uuid == null) { if (user == null) {
throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number)); throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number));
} }
return uuid; return user.getServiceId();
} }
private Map<String, ACI> getRegisteredUsersV1(final Set<String> numbers) throws IOException { private Map<String, RegisteredUser> getRegisteredUsersV1(final Set<String> numbers) throws IOException {
final Map<String, ACI> registeredUsers; final Map<String, ACI> response;
try { try {
registeredUsers = dependencies.getAccountManager() response = dependencies.getAccountManager()
.getRegisteredUsers(ServiceConfig.getIasKeyStore(), .getRegisteredUsers(ServiceConfig.getIasKeyStore(),
numbers, numbers,
serviceEnvironmentConfig.getCdsMrenclave()); serviceEnvironmentConfig.getCdsMrenclave());
@ -146,10 +148,15 @@ public class RecipientHelper {
UnauthenticatedResponseException | InvalidKeyException | NumberFormatException e) { UnauthenticatedResponseException | InvalidKeyException | NumberFormatException e) {
throw new IOException(e); throw new IOException(e);
} }
final var registeredUsers = new HashMap<String, RegisteredUser>();
response.forEach((key, value) -> registeredUsers.put(key,
new RegisteredUser(Optional.of(value), Optional.empty())));
return registeredUsers; return registeredUsers;
} }
private Map<String, ACI> getRegisteredUsersV2(final Set<String> numbers, boolean useCompat) throws IOException { private Map<String, RegisteredUser> getRegisteredUsersV2(
final Set<String> numbers, boolean useCompat
) throws IOException {
// Only partial refresh is implemented here // Only partial refresh is implemented here
final CdsiV2Service.Response response; final CdsiV2Service.Response response;
try { try {
@ -168,16 +175,29 @@ public class RecipientHelper {
} }
logger.debug("CDSI request successful, quota used by this request: {}", response.getQuotaUsedDebugOnly()); logger.debug("CDSI request successful, quota used by this request: {}", response.getQuotaUsedDebugOnly());
final var registeredUsers = new HashMap<String, ACI>(); final var registeredUsers = new HashMap<String, RegisteredUser>();
response.getResults().forEach((key, value) -> { response.getResults()
if (value.getAci().isPresent()) { .forEach((key, value) -> registeredUsers.put(key,
registeredUsers.put(key, value.getAci().get()); new RegisteredUser(value.getAci(), Optional.of(value.getPni()))));
}
});
return registeredUsers; return registeredUsers;
} }
private ACI getRegisteredUserByUsername(String username) throws IOException { private ACI getRegisteredUserByUsername(String username) throws IOException {
return dependencies.getAccountManager().getAciByUsername(username); return dependencies.getAccountManager().getAciByUsername(username);
} }
public record RegisteredUser(Optional<ACI> aci, Optional<PNI> pni) {
public RegisteredUser {
aci = aci.isPresent() && aci.get().equals(ServiceId.UNKNOWN) ? Optional.empty() : aci;
pni = pni.isPresent() && pni.get().equals(ServiceId.UNKNOWN) ? Optional.empty() : pni;
if (aci.isEmpty() && pni.isEmpty()) {
throw new AssertionError("Must have either a ACI or PNI!");
}
}
public ServiceId getServiceId() {
return aci.map(a -> (ServiceId) a).or(this::pni).orElse(null);
}
}
} }

View file

@ -1200,6 +1200,13 @@ public class SignalAccount implements Closeable {
public RecipientId resolveRecipientTrusted(final SignalServiceAddress address) { public RecipientId resolveRecipientTrusted(final SignalServiceAddress address) {
return getRecipientStore().resolveRecipientTrusted(address); return getRecipientStore().resolveRecipientTrusted(address);
} }
@Override
public RecipientId resolveRecipientTrusted(
final Optional<ACI> aci, final Optional<PNI> pni, final Optional<String> number
) {
return getRecipientStore().resolveRecipientTrusted(aci, pni, number);
}
}; };
} }

View file

@ -12,6 +12,7 @@ import org.signal.libsignal.zkgroup.profiles.ProfileKey;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
@ -154,7 +155,7 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
} }
public RecipientId resolveRecipient( public RecipientId resolveRecipient(
final String number, Supplier<ACI> aciSupplier final String number, Supplier<ServiceId> serviceIdSupplier
) throws UnregisteredRecipientException { ) throws UnregisteredRecipientException {
final Optional<RecipientWithAddress> byNumber; final Optional<RecipientWithAddress> byNumber;
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {
@ -163,12 +164,13 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
throw new RuntimeException("Failed read from recipient store", e); throw new RuntimeException("Failed read from recipient store", e);
} }
if (byNumber.isEmpty() || byNumber.get().address().serviceId().isEmpty()) { if (byNumber.isEmpty() || byNumber.get().address().serviceId().isEmpty()) {
final var aci = aciSupplier.get(); final var serviceId = serviceIdSupplier.get();
if (aci == null) { if (serviceId == null) {
throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number)); throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null,
number));
} }
return resolveRecipient(new RecipientAddress(aci, number), false, false); return resolveRecipient(new RecipientAddress(serviceId, number), false, false);
} }
return byNumber.get().id(); return byNumber.get().id();
} }
@ -191,6 +193,14 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
return resolveRecipient(new RecipientAddress(address), true, false); return resolveRecipient(new RecipientAddress(address), true, false);
} }
@Override
public RecipientId resolveRecipientTrusted(
final Optional<ACI> aci, final Optional<PNI> pni, final Optional<String> number
) {
final var serviceId = aci.map(a -> (ServiceId) a).or(() -> pni);
return resolveRecipient(new RecipientAddress(serviceId, number), true, false);
}
@Override @Override
public void storeContact(RecipientId recipientId, final Contact contact) { public void storeContact(RecipientId recipientId, final Contact contact) {
try (final var connection = database.getConnection()) { try (final var connection = database.getConnection()) {

View file

@ -1,10 +1,16 @@
package org.asamk.signal.manager.storage.recipients; package org.asamk.signal.manager.storage.recipients;
import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.util.Optional;
public interface RecipientTrustedResolver { public interface RecipientTrustedResolver {
RecipientId resolveSelfRecipientTrusted(RecipientAddress address); RecipientId resolveSelfRecipientTrusted(RecipientAddress address);
RecipientId resolveRecipientTrusted(SignalServiceAddress address); RecipientId resolveRecipientTrusted(SignalServiceAddress address);
RecipientId resolveRecipientTrusted(Optional<ACI> aci, Optional<PNI> pni, Optional<String> number);
} }