Implement full CDSI refresh

This commit is contained in:
AsamK 2023-10-15 22:36:45 +02:00
parent 7cd24a74af
commit 5c39344cff
6 changed files with 278 additions and 11 deletions

View file

@ -14,6 +14,7 @@ import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceId.ACI;
import org.whispersystems.signalservice.api.push.ServiceId.PNI;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.push.exceptions.CdsiInvalidTokenException;
import org.whispersystems.signalservice.api.services.CdsiV2Service;
import org.whispersystems.util.Base64UrlSafe;
@ -115,6 +116,10 @@ public class RecipientHelper {
}
}
public void refreshUsers() throws IOException {
getRegisteredUsers(account.getRecipientStore().getAllNumbers(), false);
}
public RecipientId refreshRegisteredUser(RecipientId recipientId) throws IOException, UnregisteredRecipientException {
final var address = resolveSignalServiceAddress(recipientId);
if (address.getNumber().isEmpty()) {
@ -126,8 +131,16 @@ public class RecipientHelper {
.resolveRecipientTrusted(new SignalServiceAddress(serviceId, number));
}
public Map<String, RegisteredUser> getRegisteredUsers(final Set<String> numbers) throws IOException {
Map<String, RegisteredUser> registeredUsers = getRegisteredUsersV2(numbers, true);
public Map<String, RegisteredUser> getRegisteredUsers(
final Set<String> numbers
) throws IOException {
return getRegisteredUsers(numbers, true);
}
private Map<String, RegisteredUser> getRegisteredUsers(
final Set<String> numbers, final boolean isPartialRefresh
) throws IOException {
Map<String, RegisteredUser> registeredUsers = getRegisteredUsersV2(numbers, isPartialRefresh, true);
// Store numbers as recipients, so we have the number/uuid association
registeredUsers.forEach((number, u) -> account.getRecipientTrustedResolver()
@ -139,7 +152,7 @@ public class RecipientHelper {
private ServiceId getRegisteredUserByNumber(final String number) throws IOException, UnregisteredRecipientException {
final Map<String, RegisteredUser> aciMap;
try {
aciMap = getRegisteredUsers(Set.of(number));
aciMap = getRegisteredUsers(Set.of(number), true);
} catch (NumberFormatException e) {
throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number));
}
@ -151,22 +164,50 @@ public class RecipientHelper {
}
private Map<String, RegisteredUser> getRegisteredUsersV2(
final Set<String> numbers, boolean useCompat
final Set<String> numbers, boolean isPartialRefresh, boolean useCompat
) throws IOException {
// Only partial refresh is implemented here
final var previousNumbers = isPartialRefresh ? Set.<String>of() : account.getCdsiStore().getAllNumbers();
final var newNumbers = new HashSet<>(numbers) {{
removeAll(previousNumbers);
}};
if (newNumbers.isEmpty() && previousNumbers.isEmpty()) {
logger.debug("No new numbers to query.");
return Map.of();
}
logger.trace("Querying CDSI for {} new numbers ({} previous)", newNumbers.size(), previousNumbers.size());
final var token = previousNumbers.isEmpty()
? Optional.<byte[]>empty()
: Optional.ofNullable(account.getCdsiToken());
final CdsiV2Service.Response response;
try {
response = dependencies.getAccountManager()
.getRegisteredUsersWithCdsi(Set.of(),
numbers,
.getRegisteredUsersWithCdsi(previousNumbers,
newNumbers,
account.getRecipientStore().getServiceIdToProfileKeyMap(),
useCompat,
Optional.empty(),
token,
serviceEnvironmentConfig.cdsiMrenclave(),
null,
token -> {
// Not storing for partial refresh
newToken -> {
if (isPartialRefresh) {
account.getCdsiStore().updateAfterPartialCdsQuery(newNumbers);
// Not storing newToken for partial refresh
} else {
final var fullNumbers = new HashSet<>(previousNumbers) {{
addAll(newNumbers);
}};
final var seenNumbers = new HashSet<>(numbers) {{
addAll(newNumbers);
}};
account.getCdsiStore().updateAfterFullCdsQuery(fullNumbers, seenNumbers);
account.setCdsiToken(newToken);
}
});
} catch (CdsiInvalidTokenException e) {
account.setCdsiToken(null);
account.getCdsiStore().clearAll();
throw e;
} catch (NumberFormatException e) {
throw new IOException(e);
}

View file

@ -9,6 +9,7 @@ import org.asamk.signal.manager.storage.keyValue.KeyValueStore;
import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.recipients.CdsiStore;
import org.asamk.signal.manager.storage.recipients.RecipientStore;
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
import org.asamk.signal.manager.storage.senderKeys.SenderKeyRecordStore;
@ -31,7 +32,7 @@ import java.util.UUID;
public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 17;
private static final long DATABASE_VERSION = 18;
private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource);
@ -55,6 +56,7 @@ public class AccountDatabase extends Database {
SenderKeyRecordStore.createSql(connection);
SenderKeySharedStore.createSql(connection);
KeyValueStore.createSql(connection);
CdsiStore.createSql(connection);
}
@Override
@ -517,5 +519,17 @@ public class AccountDatabase extends Database {
""");
}
}
if (oldVersion < 18) {
logger.debug("Updating database: Adding cdsi table");
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE cdsi (
_id INTEGER PRIMARY KEY,
number TEXT NOT NULL UNIQUE,
last_seen_at INTEGER NOT NULL
) STRICT;
""");
}
}
}
}

View file

@ -33,6 +33,7 @@ import org.asamk.signal.manager.storage.profiles.LegacyProfileStore;
import org.asamk.signal.manager.storage.profiles.ProfileStore;
import org.asamk.signal.manager.storage.protocol.LegacyJsonSignalProtocolStore;
import org.asamk.signal.manager.storage.protocol.SignalProtocolStore;
import org.asamk.signal.manager.storage.recipients.CdsiStore;
import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore;
import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore2;
import org.asamk.signal.manager.storage.recipients.RecipientAddress;
@ -145,6 +146,7 @@ public class SignalAccount implements Closeable {
private final KeyValueEntry<Long> lastReceiveTimestamp = new KeyValueEntry<>("last-receive-timestamp",
long.class,
0L);
private final KeyValueEntry<byte[]> cdsiToken = new KeyValueEntry<>("cdsi-token", byte[].class);
private final KeyValueEntry<Long> storageManifestVersion = new KeyValueEntry<>("storage-manifest-version",
long.class,
-1L);
@ -160,6 +162,7 @@ public class SignalAccount implements Closeable {
private StickerStore stickerStore;
private ConfigurationStore configurationStore;
private KeyValueStore keyValueStore;
private CdsiStore cdsiStore;
private MessageCache messageCache;
private MessageSendLogStore messageSendLogStore;
@ -1220,6 +1223,10 @@ public class SignalAccount implements Closeable {
return getRecipientStore();
}
public CdsiStore getCdsiStore() {
return getOrCreate(() -> cdsiStore, () -> cdsiStore = new CdsiStore(getAccountDatabase()));
}
private RecipientIdCreator getRecipientIdCreator() {
return recipientId -> getRecipientStore().create(recipientId);
}
@ -1571,6 +1578,14 @@ public class SignalAccount implements Closeable {
}
}
public byte[] getCdsiToken() {
return getKeyValueStore().getEntry(cdsiToken);
}
public void setCdsiToken(final byte[] value) {
getKeyValueStore().storeEntry(cdsiToken, value);
}
public ProfileKey getProfileKey() {
return profileKey;
}

View file

@ -90,6 +90,8 @@ public class KeyValueStore {
value = resultSet.getLong("value");
} else if (clazz == boolean.class || clazz == Boolean.class) {
value = resultSet.getBoolean("value");
} else if (clazz == byte[].class || clazz == Byte[].class) {
value = resultSet.getBytes("value");
} else if (clazz == String.class) {
value = resultSet.getString("value");
} else if (Enum.class.isAssignableFrom(clazz)) {
@ -134,6 +136,12 @@ public class KeyValueStore {
} else {
statement.setBoolean(parameterIndex, (boolean) value);
}
} else if (clazz == byte[].class || clazz == Byte[].class) {
if (value == null) {
statement.setNull(parameterIndex, Types.BLOB);
} else {
statement.setBytes(parameterIndex, (byte[]) value);
}
} else if (clazz == String.class) {
if (value == null) {
statement.setNull(parameterIndex, Types.VARCHAR);

View file

@ -0,0 +1,170 @@
package org.asamk.signal.manager.storage.recipients;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
public class CdsiStore {
private static final String TABLE_CDSI = "cdsi";
private final Database database;
public static void createSql(Connection connection) throws SQLException {
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE cdsi (
_id INTEGER PRIMARY KEY,
number TEXT NOT NULL UNIQUE,
last_seen_at INTEGER NOT NULL
) STRICT;
""");
}
}
public CdsiStore(final Database database) {
this.database = database;
}
public Set<String> getAllNumbers() {
try (final var connection = database.getConnection()) {
return getAllNumbers(connection);
} catch (SQLException e) {
throw new RuntimeException("Failed read from cdsi store", e);
}
}
/**
* Saves the set of e164 numbers used after a full refresh.
*
* @param fullNumbers All the e164 numbers used in the last CDS query (previous and new).
* @param seenNumbers The E164 numbers that were seen in either the system contacts or recipients table. This is different from fullNumbers in that fullNumbers
* includes every number we've ever seen, even if it's not in our contacts anymore.
*/
public void updateAfterFullCdsQuery(Set<String> fullNumbers, Set<String> seenNumbers) {
final var lastSeen = System.currentTimeMillis();
try (final var connection = database.getConnection()) {
final var existingNumbers = getAllNumbers(connection);
final var removedNumbers = new HashSet<>(existingNumbers) {{
removeAll(fullNumbers);
}};
removeNumbers(connection, removedNumbers);
final var addedNumbers = new HashSet<>(fullNumbers) {{
removeAll(existingNumbers);
}};
addNumbers(connection, addedNumbers, lastSeen);
updateLastSeen(connection, seenNumbers, lastSeen);
} catch (SQLException e) {
throw new RuntimeException("Failed update cdsi store", e);
}
}
/**
* Updates after a partial CDS query. Will not insert new entries.
* Instead, this will simply update the lastSeen timestamp of any entry we already have.
*
* @param seenNumbers The newly-added E164 numbers that we hadn't previously queried for.
*/
public void updateAfterPartialCdsQuery(Set<String> seenNumbers) {
final var lastSeen = System.currentTimeMillis();
try (final var connection = database.getConnection()) {
updateLastSeen(connection, seenNumbers, lastSeen);
} catch (SQLException e) {
throw new RuntimeException("Failed update cdsi store", e);
}
}
private static Set<String> getAllNumbers(final Connection connection) throws SQLException {
final var sql = (
"""
SELECT c.number
FROM %s c
"""
).formatted(TABLE_CDSI);
try (final var statement = connection.prepareStatement(sql)) {
try (var result = Utils.executeQueryForStream(statement, r -> r.getString("number"))) {
return result.collect(Collectors.toSet());
}
}
}
private static void removeNumbers(
final Connection connection, final Set<String> numbers
) throws SQLException {
final var sql = (
"""
DELETE FROM %s
WHERE number = ?
"""
).formatted(TABLE_CDSI);
try (final var statement = connection.prepareStatement(sql)) {
for (final var number : numbers) {
statement.setString(1, number);
statement.executeUpdate();
}
}
}
private static void addNumbers(
final Connection connection, final Set<String> numbers, final long lastSeen
) throws SQLException {
final var sql = (
"""
INSERT INTO %s (number, last_seen_at)
VALUES (?, ?)
ON CONFLICT (number) DO UPDATE SET last_seen_at = excluded.last_seen_at
"""
).formatted(TABLE_CDSI);
try (final var statement = connection.prepareStatement(sql)) {
for (final var number : numbers) {
statement.setString(1, number);
statement.setLong(2, lastSeen);
statement.executeUpdate();
}
}
}
private static void updateLastSeen(
final Connection connection, final Set<String> numbers, final long lastSeen
) throws SQLException {
final var sql = (
"""
UPDATE %s
SET last_seen_at = ?
WHERE number = ?
"""
).formatted(TABLE_CDSI);
try (final var statement = connection.prepareStatement(sql)) {
for (final var number : numbers) {
statement.setLong(1, lastSeen);
statement.setString(2, number);
statement.executeUpdate();
}
}
}
public void clearAll() {
final var sql = (
"""
TRUNCATE %s
"""
).formatted(TABLE_CDSI);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update cdsi store", e);
}
}
}

View file

@ -383,6 +383,25 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
}
}
public Set<String> getAllNumbers() {
final var sql = (
"""
SELECT r.number
FROM %s r
WHERE r.number IS NOT NULL
"""
).formatted(TABLE_RECIPIENT);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
return Utils.executeQueryForStream(statement, resultSet -> resultSet.getString("number"))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from recipient store", e);
}
}
public Map<ServiceId, ProfileKey> getServiceIdToProfileKeyMap() {
final var sql = (
"""