Move sender key store to database

This commit is contained in:
AsamK 2022-06-10 23:21:39 +02:00
parent 0c4a037dde
commit 08dc65350f
8 changed files with 569 additions and 330 deletions

View file

@ -1168,16 +1168,18 @@
] ]
}, },
{ {
"name":"org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore$Storage", "name":"org.asamk.signal.manager.storage.senderKeys.LegacySenderKeySharedStore$Storage",
"allDeclaredFields":true, "allDeclaredFields":true,
"allDeclaredMethods":true, "queryAllDeclaredMethods":true,
"allDeclaredConstructors":true "queryAllDeclaredConstructors":true,
"methods":[{"name":"<init>","parameterTypes":["java.util.List"] }]
}, },
{ {
"name":"org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore$Storage$SharedSenderKey", "name":"org.asamk.signal.manager.storage.senderKeys.LegacySenderKeySharedStore$Storage$SharedSenderKey",
"allDeclaredFields":true, "allDeclaredFields":true,
"allDeclaredMethods":true, "queryAllDeclaredMethods":true,
"allDeclaredConstructors":true "queryAllDeclaredConstructors":true,
"methods":[{"name":"<init>","parameterTypes":["long","int","java.lang.String"] }]
}, },
{ {
"name":"org.asamk.signal.manager.storage.stickerPacks.JsonStickerPack", "name":"org.asamk.signal.manager.storage.stickerPacks.JsonStickerPack",

View file

@ -8,6 +8,8 @@ import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore; import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientStore;
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore; import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
import org.asamk.signal.manager.storage.senderKeys.SenderKeyRecordStore;
import org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore;
import org.asamk.signal.manager.storage.sessions.SessionStore; import org.asamk.signal.manager.storage.sessions.SessionStore;
import org.asamk.signal.manager.storage.stickers.StickerStore; import org.asamk.signal.manager.storage.stickers.StickerStore;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -20,7 +22,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database { public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 7; private static final long DATABASE_VERSION = 8;
private AccountDatabase(final HikariDataSource dataSource) { private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource); super(logger, DATABASE_VERSION, dataSource);
@ -40,6 +42,8 @@ public class AccountDatabase extends Database {
GroupStore.createSql(connection); GroupStore.createSql(connection);
SessionStore.createSql(connection); SessionStore.createSql(connection);
IdentityKeyStore.createSql(connection); IdentityKeyStore.createSql(connection);
SenderKeyRecordStore.createSql(connection);
SenderKeySharedStore.createSql(connection);
} }
@Override @Override
@ -176,5 +180,29 @@ public class AccountDatabase extends Database {
"""); """);
} }
} }
if (oldVersion < 8) {
logger.debug("Updating database: Creating sender key tables");
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE sender_key (
_id INTEGER PRIMARY KEY,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL,
record BLOB NOT NULL,
created_timestamp INTEGER NOT NULL,
UNIQUE(recipient_id, device_id, distribution_id)
);
CREATE TABLE sender_key_shared (
_id INTEGER PRIMARY KEY,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL,
timestamp INTEGER NOT NULL,
UNIQUE(recipient_id, device_id, distribution_id)
);
""");
}
}
} }
} }

View file

@ -38,6 +38,8 @@ import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientStore;
import org.asamk.signal.manager.storage.recipients.RecipientTrustedResolver; import org.asamk.signal.manager.storage.recipients.RecipientTrustedResolver;
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore; import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
import org.asamk.signal.manager.storage.senderKeys.LegacySenderKeyRecordStore;
import org.asamk.signal.manager.storage.senderKeys.LegacySenderKeySharedStore;
import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore; import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore;
import org.asamk.signal.manager.storage.sessions.LegacySessionStore; import org.asamk.signal.manager.storage.sessions.LegacySessionStore;
import org.asamk.signal.manager.storage.sessions.SessionStore; import org.asamk.signal.manager.storage.sessions.SessionStore;
@ -668,6 +670,16 @@ public class SignalAccount implements Closeable {
migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig; migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig;
final var legacySenderKeysPath = getSenderKeysPath(dataPath, accountPath);
if (legacySenderKeysPath.exists()) {
LegacySenderKeyRecordStore.migrate(legacySenderKeysPath, getRecipientResolver(), getSenderKeyStore());
migratedLegacyConfig = true;
}
final var legacySenderKeysSharedPath = getSharedSenderKeysFile(dataPath, accountPath);
if (legacySenderKeysSharedPath.exists()) {
LegacySenderKeySharedStore.migrate(legacySenderKeysSharedPath, getRecipientResolver(), getSenderKeyStore());
migratedLegacyConfig = true;
}
if (rootNode.hasNonNull("groupStore")) { if (rootNode.hasNonNull("groupStore")) {
final var groupStoreStorage = jsonProcessor.convertValue(rootNode.get("groupStore"), final var groupStoreStorage = jsonProcessor.convertValue(rootNode.get("groupStore"),
LegacyGroupStore.Storage.class); LegacyGroupStore.Storage.class);
@ -1196,10 +1208,10 @@ public class SignalAccount implements Closeable {
public SenderKeyStore getSenderKeyStore() { public SenderKeyStore getSenderKeyStore() {
return getOrCreate(() -> senderKeyStore, return getOrCreate(() -> senderKeyStore,
() -> senderKeyStore = new SenderKeyStore(getSharedSenderKeysFile(dataPath, accountPath), () -> senderKeyStore = new SenderKeyStore(getAccountDatabase(),
getSenderKeysPath(dataPath, accountPath),
getRecipientAddressResolver(), getRecipientAddressResolver(),
getRecipientResolver())); getRecipientResolver(),
getRecipientIdCreator()));
} }
public ConfigurationStore getConfigurationStore() { public ConfigurationStore getConfigurationStore() {

View file

@ -0,0 +1,101 @@
package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.asamk.signal.manager.storage.senderKeys.SenderKeyRecordStore.Key;
public class LegacySenderKeyRecordStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeyRecordStore.class);
public static void migrate(
final File senderKeysPath, final RecipientResolver resolver, SenderKeyStore senderKeyStore
) {
final var files = senderKeysPath.listFiles();
if (files == null) {
return;
}
final var senderKeys = parseFileNames(files, resolver).stream().map(key -> {
final var record = loadSenderKeyLocked(key, senderKeysPath);
if (record == null) {
return null;
}
return new Pair<>(key, record);
}).filter(Objects::nonNull).toList();
senderKeyStore.addLegacySenderKeys(senderKeys);
deleteAllSenderKeys(senderKeysPath);
}
private static void deleteAllSenderKeys(File senderKeysPath) {
final var files = senderKeysPath.listFiles();
if (files == null) {
return;
}
for (var file : files) {
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
}
}
try {
Files.delete(senderKeysPath.toPath());
} catch (IOException e) {
logger.error("Failed to delete sender keys directory {}: {}", senderKeysPath, e.getMessage());
}
}
final static Pattern senderKeyFileNamePattern = Pattern.compile("(\\d+)_(\\d+)_([\\da-z\\-]+)");
private static List<Key> parseFileNames(final File[] files, final RecipientResolver resolver) {
return Arrays.stream(files)
.map(f -> senderKeyFileNamePattern.matcher(f.getName()))
.filter(Matcher::matches)
.map(matcher -> {
final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
if (recipientId == null) {
return null;
}
return new Key(recipientId, Integer.parseInt(matcher.group(2)), UUID.fromString(matcher.group(3)));
})
.filter(Objects::nonNull)
.toList();
}
private static File getSenderKeyFile(Key key, final File senderKeysPath) {
return new File(senderKeysPath,
key.recipientId().id() + "_" + key.deviceId() + "_" + key.distributionId().toString());
}
private static SenderKeyRecord loadSenderKeyLocked(final Key key, final File senderKeysPath) {
final var file = getSenderKeyFile(key, senderKeysPath);
if (!file.exists()) {
return null;
}
try (var inputStream = new FileInputStream(file)) {
return new SenderKeyRecord(inputStream.readAllBytes());
} catch (IOException | InvalidMessageException e) {
logger.warn("Failed to load sender key, resetting sender key: {}", e.getMessage());
return null;
}
}
}

View file

@ -0,0 +1,56 @@
package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.storage.senderKeys.SenderKeySharedStore.SenderKeySharedEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.DistributionId;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class LegacySenderKeySharedStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySenderKeySharedStore.class);
public static void migrate(
final File file, final RecipientResolver resolver, SenderKeyStore senderKeyStore
) {
final var objectMapper = Utils.createStorageObjectMapper();
try (var inputStream = new FileInputStream(file)) {
final var storage = objectMapper.readValue(inputStream, Storage.class);
final var sharedSenderKeys = new HashMap<DistributionId, Set<SenderKeySharedEntry>>();
for (final var senderKey : storage.sharedSenderKeys) {
final var recipientId = resolver.resolveRecipient(senderKey.recipientId);
if (recipientId == null) {
continue;
}
final var entry = new SenderKeySharedEntry(recipientId, senderKey.deviceId);
final var distributionId = DistributionId.from(senderKey.distributionId);
var entries = sharedSenderKeys.get(distributionId);
if (entries == null) {
entries = new HashSet<>();
}
entries.add(entry);
sharedSenderKeys.put(distributionId, entries);
}
senderKeyStore.addLegacySenderKeysShared(sharedSenderKeys);
Files.delete(file.toPath());
} catch (IOException e) {
logger.info("Failed to load shared sender key store, ignoring", e);
}
}
private record Storage(List<SharedSenderKey> sharedSenderKeys) {
private record SharedSenderKey(long recipientId, int deviceId, String distributionId) {}
}
}

View file

@ -1,43 +1,53 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.util.IOUtils;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.signal.libsignal.protocol.groups.state.SenderKeyStore; import org.signal.libsignal.protocol.groups.state.SenderKeyStore;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.util.UuidUtil;
import java.io.File; import java.sql.Connection;
import java.io.FileInputStream; import java.sql.ResultSet;
import java.io.FileOutputStream; import java.sql.SQLException;
import java.io.IOException; import java.util.Collection;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID; import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class SenderKeyRecordStore implements SenderKeyStore { public class SenderKeyRecordStore implements SenderKeyStore {
private final static Logger logger = LoggerFactory.getLogger(SenderKeyRecordStore.class); private final static Logger logger = LoggerFactory.getLogger(SenderKeyRecordStore.class);
private final static String TABLE_SENDER_KEY = "sender_key";
private final Map<Key, SenderKeyRecord> cachedSenderKeys = new HashMap<>(); private final Database database;
private final File senderKeysPath;
private final RecipientResolver resolver; private final RecipientResolver resolver;
public SenderKeyRecordStore( public static void createSql(Connection connection) throws SQLException {
final File senderKeysPath, final RecipientResolver resolver // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE sender_key (
_id INTEGER PRIMARY KEY,
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
device_id INTEGER NOT NULL,
distribution_id BLOB NOT NULL,
record BLOB NOT NULL,
created_timestamp INTEGER NOT NULL,
UNIQUE(recipient_id, device_id, distribution_id)
);
""");
}
}
SenderKeyRecordStore(
final Database database, final RecipientResolver resolver
) { ) {
this.senderKeysPath = senderKeysPath; this.database = database;
this.resolver = resolver; this.resolver = resolver;
} }
@ -45,8 +55,10 @@ public class SenderKeyRecordStore implements SenderKeyStore {
public SenderKeyRecord loadSenderKey(final SignalProtocolAddress address, final UUID distributionId) { public SenderKeyRecord loadSenderKey(final SignalProtocolAddress address, final UUID distributionId) {
final var key = getKey(address, distributionId); final var key = getKey(address, distributionId);
synchronized (cachedSenderKeys) { try (final var connection = database.getConnection()) {
return loadSenderKeyLocked(key); return loadSenderKey(connection, key);
} catch (SQLException e) {
throw new RuntimeException("Failed read from sender key store", e);
} }
} }
@ -56,88 +68,111 @@ public class SenderKeyRecordStore implements SenderKeyStore {
) { ) {
final var key = getKey(address, distributionId); final var key = getKey(address, distributionId);
synchronized (cachedSenderKeys) { try (final var connection = database.getConnection()) {
storeSenderKeyLocked(key, record); storeSenderKey(connection, key, record);
} catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e);
} }
} }
long getCreateTimeForKey(final RecipientId selfRecipientId, final int selfDeviceId, final UUID distributionId) { long getCreateTimeForKey(final RecipientId selfRecipientId, final int selfDeviceId, final UUID distributionId) {
final var key = getKey(selfRecipientId, selfDeviceId, distributionId); final var sql = (
final var senderKeyFile = getSenderKeyFile(key); """
SELECT s.created_timestamp
if (!senderKeyFile.exists()) { FROM %s AS s
return -1; WHERE s.recipient_id = ? AND s.device_id = ? AND s.distribution_id = ?
"""
).formatted(TABLE_SENDER_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, selfRecipientId.id());
statement.setInt(2, selfDeviceId);
statement.setBytes(3, UuidUtil.toByteArray(distributionId));
return Utils.executeQueryForOptional(statement, res -> res.getLong("created_timestamp")).orElse(-1L);
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from sender key store", e);
} }
return IOUtils.getFileCreateTime(senderKeyFile);
} }
void deleteSenderKey(final RecipientId recipientId, final UUID distributionId) { void deleteSenderKey(final RecipientId recipientId, final UUID distributionId) {
synchronized (cachedSenderKeys) { final var sql = (
cachedSenderKeys.clear(); """
final var keys = getKeysLocked(recipientId); DELETE FROM %s AS s
for (var key : keys) { WHERE s.recipient_id = ? AND s.distribution_id = ?
if (key.distributionId.equals(distributionId)) { """
deleteSenderKeyLocked(key); ).formatted(TABLE_SENDER_KEY);
} try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.setBytes(2, UuidUtil.toByteArray(distributionId));
statement.executeUpdate();
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e);
} }
} }
void deleteAll() { void deleteAll() {
synchronized (cachedSenderKeys) { final var sql = """
cachedSenderKeys.clear(); DELETE FROM %s AS s
final var files = senderKeysPath.listFiles((_file, s) -> senderKeyFileNamePattern.matcher(s).matches()); """.formatted(TABLE_SENDER_KEY);
if (files == null) { try (final var connection = database.getConnection()) {
return; try (final var statement = connection.prepareStatement(sql)) {
} statement.executeUpdate();
for (final var file : files) {
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
}
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e);
} }
} }
void deleteAllFor(final RecipientId recipientId) { void deleteAllFor(final RecipientId recipientId) {
synchronized (cachedSenderKeys) { try (final var connection = database.getConnection()) {
cachedSenderKeys.clear(); deleteAllFor(connection, recipientId);
final var keys = getKeysLocked(recipientId); } catch (SQLException e) {
for (var key : keys) { throw new RuntimeException("Failed update sender key store", e);
deleteSenderKeyLocked(key);
}
} }
} }
void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
synchronized (cachedSenderKeys) { try (final var connection = database.getConnection()) {
final var keys = getKeysLocked(toBeMergedRecipientId); connection.setAutoCommit(false);
final var otherHasSenderKeys = keys.size() > 0; final var sql = """
if (!otherHasSenderKeys) { UPDATE OR IGNORE %s
return; SET recipient_id = ?
} WHERE recipient_id = ?
""".formatted(TABLE_SENDER_KEY);
logger.debug("To be merged recipient had sender keys, re-assigning to the new recipient."); try (final var statement = connection.prepareStatement(sql)) {
for (var key : keys) { statement.setLong(1, recipientId.id());
final var toBeMergedSenderKey = loadSenderKeyLocked(key); statement.setLong(2, toBeMergedRecipientId.id());
deleteSenderKeyLocked(key); final var rows = statement.executeUpdate();
if (toBeMergedSenderKey == null) { if (rows > 0) {
continue; logger.debug("Reassigned {} sender keys of to be merged recipient.", rows);
} }
final var newKey = new Key(recipientId, key.deviceId(), key.distributionId);
final var senderKeyRecord = loadSenderKeyLocked(newKey);
if (senderKeyRecord != null) {
continue;
}
storeSenderKeyLocked(newKey, toBeMergedSenderKey);
} }
// Delete all conflicting sender keys now
deleteAllFor(connection, toBeMergedRecipientId);
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update sender key store", e);
} }
} }
void addLegacySenderKeys(final Collection<Pair<Key, SenderKeyRecord>> senderKeys) {
logger.debug("Migrating legacy sender keys to database");
long start = System.nanoTime();
try (final var connection = database.getConnection()) {
connection.setAutoCommit(false);
for (final var pair : senderKeys) {
storeSenderKey(connection, pair.first(), pair.second());
}
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update sender keys store", e);
}
logger.debug("Complete sender keys migration took {}ms", (System.nanoTime() - start) / 1000000);
}
/** /**
* @param identifier can be either a serialized uuid or an e164 phone number * @param identifier can be either a serialized uuid or an e164 phone number
*/ */
@ -145,106 +180,86 @@ public class SenderKeyRecordStore implements SenderKeyStore {
return resolver.resolveRecipient(identifier); return resolver.resolveRecipient(identifier);
} }
private Key getKey(final RecipientId recipientId, int deviceId, final UUID distributionId) {
return new Key(recipientId, deviceId, distributionId);
}
private Key getKey(final SignalProtocolAddress address, final UUID distributionId) { private Key getKey(final SignalProtocolAddress address, final UUID distributionId) {
final var recipientId = resolveRecipient(address.getName()); final var recipientId = resolveRecipient(address.getName());
return new Key(recipientId, address.getDeviceId(), distributionId); return new Key(recipientId, address.getDeviceId(), distributionId);
} }
private List<Key> getKeysLocked(RecipientId recipientId) { private SenderKeyRecord loadSenderKey(final Connection connection, final Key key) throws SQLException {
final var files = senderKeysPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_")); final var sql = (
if (files == null) { """
return List.of(); SELECT s.record
FROM %s AS s
WHERE s.recipient_id = ? AND s.device_id = ? AND s.distribution_id = ?
"""
).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, key.recipientId().id());
statement.setInt(2, key.deviceId());
statement.setBytes(3, UuidUtil.toByteArray(key.distributionId()));
return Utils.executeQueryForOptional(statement, this::getSenderKeyRecordFromResultSet).orElse(null);
} }
return parseFileNames(files);
} }
final Pattern senderKeyFileNamePattern = Pattern.compile("(\\d+)_(\\d+)_([\\da-z\\-]+)"); private void storeSenderKey(
final Connection connection, final Key key, final SenderKeyRecord senderKeyRecord
private List<Key> parseFileNames(final File[] files) { ) throws SQLException {
return Arrays.stream(files) final var sqlUpdate = """
.map(f -> senderKeyFileNamePattern.matcher(f.getName())) UPDATE %s
.filter(Matcher::matches) SET record = ?
.map(matcher -> { WHERE recipient_id = ? AND device_id = ? and distribution_id = ?
final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1))); """.formatted(TABLE_SENDER_KEY);
if (recipientId == null) { try (final var statement = connection.prepareStatement(sqlUpdate)) {
return null; statement.setBytes(1, senderKeyRecord.serialize());
} statement.setLong(2, key.recipientId().id());
return new Key(recipientId, Integer.parseInt(matcher.group(2)), UUID.fromString(matcher.group(3))); statement.setLong(3, key.deviceId());
}) statement.setBytes(4, UuidUtil.toByteArray(key.distributionId()));
.filter(Objects::nonNull) final var rows = statement.executeUpdate();
.toList(); if (rows > 0) {
} return;
private File getSenderKeyFile(Key key) {
try {
IOUtils.createPrivateDirectories(senderKeysPath);
} catch (IOException e) {
throw new AssertionError("Failed to create sender keys path: " + e.getMessage(), e);
}
return new File(senderKeysPath,
key.recipientId().id() + "_" + key.deviceId() + "_" + key.distributionId.toString());
}
private SenderKeyRecord loadSenderKeyLocked(final Key key) {
{
final var senderKeyRecord = cachedSenderKeys.get(key);
if (senderKeyRecord != null) {
return senderKeyRecord;
} }
} }
final var file = getSenderKeyFile(key); // Record doesn't exist yet, creating a new one
if (!file.exists()) { final var sqlInsert = (
return null; """
INSERT OR REPLACE INTO %s (recipient_id, device_id, distribution_id, record, created_timestamp)
VALUES (?, ?, ?, ?, ?)
"""
).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sqlInsert)) {
statement.setLong(1, key.recipientId().id());
statement.setInt(2, key.deviceId());
statement.setBytes(3, UuidUtil.toByteArray(key.distributionId()));
statement.setBytes(4, senderKeyRecord.serialize());
statement.setLong(5, System.currentTimeMillis());
statement.executeUpdate();
} }
try (var inputStream = new FileInputStream(file)) { }
final var senderKeyRecord = new SenderKeyRecord(inputStream.readAllBytes());
cachedSenderKeys.put(key, senderKeyRecord); private void deleteAllFor(final Connection connection, final RecipientId recipientId) throws SQLException {
return senderKeyRecord; final var sql = (
} catch (IOException | InvalidMessageException e) { """
logger.warn("Failed to load sender key, resetting sender key: {}", e.getMessage()); DELETE FROM %s AS s
WHERE s.recipient_id = ?
"""
).formatted(TABLE_SENDER_KEY);
try (final var statement = connection.prepareStatement(sql)) {
statement.setLong(1, recipientId.id());
statement.executeUpdate();
}
}
private SenderKeyRecord getSenderKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
try {
final var record = resultSet.getBytes("record");
return new SenderKeyRecord(record);
} catch (InvalidMessageException e) {
logger.warn("Failed to load sender key, resetting: {}", e.getMessage());
return null; return null;
} }
} }
private void storeSenderKeyLocked(final Key key, final SenderKeyRecord senderKeyRecord) { record Key(RecipientId recipientId, int deviceId, UUID distributionId) {}
cachedSenderKeys.put(key, senderKeyRecord);
final var file = getSenderKeyFile(key);
try {
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(senderKeyRecord.serialize());
}
} catch (IOException e) {
logger.warn("Failed to store sender key, trying to delete file and retry: {}", e.getMessage());
try {
Files.delete(file.toPath());
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(senderKeyRecord.serialize());
}
} catch (IOException e2) {
logger.error("Failed to store sender key file {}: {}", file, e2.getMessage());
}
}
}
private void deleteSenderKeyLocked(final Key key) {
cachedSenderKeys.remove(key);
final var file = getSenderKeyFile(key);
if (!file.exists()) {
return;
}
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
}
}
private record Key(RecipientId recipientId, int deviceId, UUID distributionId) {}
} }

View file

@ -1,10 +1,10 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.asamk.signal.manager.helper.RecipientAddressResolver; import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.Utils;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -12,94 +12,70 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
import java.io.ByteArrayInputStream; import java.sql.Connection;
import java.io.ByteArrayOutputStream; import java.sql.ResultSet;
import java.io.File; import java.sql.SQLException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class SenderKeySharedStore { public class SenderKeySharedStore {
private final static Logger logger = LoggerFactory.getLogger(SenderKeySharedStore.class); private final static Logger logger = LoggerFactory.getLogger(SenderKeySharedStore.class);
private final static String TABLE_SENDER_KEY_SHARED = "sender_key_shared";
private final Map<UUID, Set<SenderKeySharedEntry>> sharedSenderKeys; private final Database database;
private final RecipientIdCreator recipientIdCreator;
private final ObjectMapper objectMapper;
private final File file;
private final RecipientResolver resolver; private final RecipientResolver resolver;
private final RecipientAddressResolver addressResolver; private final RecipientAddressResolver addressResolver;
public static SenderKeySharedStore load( public static void createSql(Connection connection) throws SQLException {
final File file, final RecipientAddressResolver addressResolver, final RecipientResolver resolver // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
) { try (final var statement = connection.createStatement()) {
final var objectMapper = Utils.createStorageObjectMapper(); statement.executeUpdate("""
try (var inputStream = new FileInputStream(file)) { CREATE TABLE sender_key_shared (
final var storage = objectMapper.readValue(inputStream, Storage.class); _id INTEGER PRIMARY KEY,
final var sharedSenderKeys = new HashMap<UUID, Set<SenderKeySharedEntry>>(); recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
for (final var senderKey : storage.sharedSenderKeys) { device_id INTEGER NOT NULL,
final var recipientId = resolver.resolveRecipient(senderKey.recipientId); distribution_id BLOB NOT NULL,
if (recipientId == null) { timestamp INTEGER NOT NULL,
continue; UNIQUE(recipient_id, device_id, distribution_id)
} );
final var entry = new SenderKeySharedEntry(recipientId, senderKey.deviceId); """);
final var distributionId = UuidUtil.parseOrNull(senderKey.distributionId);
if (distributionId == null) {
logger.warn("Read invalid distribution id from storage {}, ignoring", senderKey.distributionId);
continue;
}
var entries = sharedSenderKeys.get(distributionId);
if (entries == null) {
entries = new HashSet<>();
}
entries.add(entry);
sharedSenderKeys.put(distributionId, entries);
}
return new SenderKeySharedStore(sharedSenderKeys, objectMapper, file, addressResolver, resolver);
} catch (FileNotFoundException e) {
logger.trace("Creating new shared sender key store.");
return new SenderKeySharedStore(new HashMap<>(), objectMapper, file, addressResolver, resolver);
} catch (IOException e) {
logger.warn("Failed to load shared sender key store", e);
throw new RuntimeException(e);
} }
} }
private SenderKeySharedStore( SenderKeySharedStore(
final Map<UUID, Set<SenderKeySharedEntry>> sharedSenderKeys, final Database database,
final ObjectMapper objectMapper, final RecipientIdCreator recipientIdCreator,
final File file,
final RecipientAddressResolver addressResolver, final RecipientAddressResolver addressResolver,
final RecipientResolver resolver final RecipientResolver resolver
) { ) {
this.sharedSenderKeys = sharedSenderKeys; this.database = database;
this.objectMapper = objectMapper; this.recipientIdCreator = recipientIdCreator;
this.file = file;
this.addressResolver = addressResolver; this.addressResolver = addressResolver;
this.resolver = resolver; this.resolver = resolver;
} }
public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) { public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
final var addresses = sharedSenderKeys.get(distributionId.asUuid()); final var sql = (
if (addresses == null) { """
return Set.of(); SELECT s.recipient_id, s.device_id
FROM %s AS s
WHERE s.distribution_id = ?
"""
).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) {
statement.setBytes(1, UuidUtil.toByteArray(distributionId.asUuid()));
return Utils.executeQueryForStream(statement, this::getSenderKeySharedEntryFromResultSet)
.map(k -> new SignalProtocolAddress(addressResolver.resolveRecipientAddress(k.recipientId())
.getIdentifier(), k.deviceId()))
.collect(Collectors.toSet());
} }
return addresses.stream() } catch (SQLException e) {
.map(k -> new SignalProtocolAddress(addressResolver.resolveRecipientAddress(k.recipientId()) throw new RuntimeException("Failed read from shared sender key store", e);
.getIdentifier(), k.deviceId()))
.collect(Collectors.toSet());
} }
} }
@ -107,135 +83,173 @@ public class SenderKeySharedStore {
final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
) { ) {
final var newEntries = addresses.stream() final var newEntries = addresses.stream()
.map(a -> new SenderKeySharedEntry(resolveRecipient(a.getName()), a.getDeviceId())) .map(a -> new SenderKeySharedEntry(resolver.resolveRecipient(a.getName()), a.getDeviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
final var previousEntries = sharedSenderKeys.getOrDefault(distributionId.asUuid(), Set.of()); connection.setAutoCommit(false);
markSenderKeysSharedWith(connection, distributionId, newEntries);
sharedSenderKeys.put(distributionId.asUuid(), new HashSet<>() { connection.commit();
{ } catch (SQLException e) {
addAll(previousEntries); throw new RuntimeException("Failed update shared sender key store", e);
addAll(newEntries);
}
});
saveLocked();
} }
} }
public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) { public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
final var entriesToDelete = addresses.stream() final var entriesToDelete = addresses.stream()
.map(a -> new SenderKeySharedEntry(resolveRecipient(a.getName()), a.getDeviceId())) .map(a -> new SenderKeySharedEntry(resolver.resolveRecipient(a.getName()), a.getDeviceId()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
for (final var distributionId : sharedSenderKeys.keySet()) { connection.setAutoCommit(false);
final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of()); final var sql = (
"""
sharedSenderKeys.put(distributionId, new HashSet<>(entries) { DELETE FROM %s AS s
{ WHERE recipient_id = ? AND device_id = ?
removeAll(entriesToDelete); """
} ).formatted(TABLE_SENDER_KEY_SHARED);
}); try (final var statement = connection.prepareStatement(sql)) {
for (final var entry : entriesToDelete) {
statement.setLong(1, entry.recipientId().id());
statement.setInt(2, entry.deviceId());
statement.executeUpdate();
}
} }
saveLocked(); connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
public void deleteAll() { public void deleteAll() {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
sharedSenderKeys.clear(); final var sql = (
saveLocked(); """
DELETE FROM %s AS s
"""
).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) {
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
public void deleteAllFor(final RecipientId recipientId) { public void deleteAllFor(final RecipientId recipientId) {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
for (final var distributionId : sharedSenderKeys.keySet()) { final var sql = (
final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of()); """
DELETE FROM %s AS s
sharedSenderKeys.put(distributionId, new HashSet<>(entries) { WHERE recipient_id = ?
{ """
removeIf(e -> e.recipientId().equals(recipientId)); ).formatted(TABLE_SENDER_KEY_SHARED);
} try (final var statement = connection.prepareStatement(sql)) {
}); statement.setLong(1, recipientId.id());
statement.executeUpdate();
} }
saveLocked(); } catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
public void deleteSharedWith( public void deleteSharedWith(
final RecipientId recipientId, final int deviceId, final DistributionId distributionId final RecipientId recipientId, final int deviceId, final DistributionId distributionId
) { ) {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
final var entries = sharedSenderKeys.getOrDefault(distributionId.asUuid(), Set.of()); final var sql = (
"""
sharedSenderKeys.put(distributionId.asUuid(), new HashSet<>(entries) { DELETE FROM %s AS s
{ WHERE recipient_id = ? AND device_id = ? AND distribution_id = ?
remove(new SenderKeySharedEntry(recipientId, deviceId)); """
} ).formatted(TABLE_SENDER_KEY_SHARED);
}); try (final var statement = connection.prepareStatement(sql)) {
saveLocked(); statement.setLong(1, recipientId.id());
statement.setInt(2, deviceId);
statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid()));
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
public void deleteAllFor(final DistributionId distributionId) { public void deleteAllFor(final DistributionId distributionId) {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
if (sharedSenderKeys.remove(distributionId.asUuid()) != null) { final var sql = (
saveLocked(); """
DELETE FROM %s AS s
WHERE distribution_id = ?
"""
).formatted(TABLE_SENDER_KEY_SHARED);
try (final var statement = connection.prepareStatement(sql)) {
statement.setBytes(1, UuidUtil.toByteArray(distributionId.asUuid()));
statement.executeUpdate();
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
synchronized (sharedSenderKeys) { try (final var connection = database.getConnection()) {
for (final var distributionId : sharedSenderKeys.keySet()) { final var sql = (
final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of()); """
UPDATE OR REPLACE %s
sharedSenderKeys.put(distributionId, SET recipient_id = ?
entries.stream() WHERE recipient_id = ?
.map(e -> e.recipientId.equals(toBeMergedRecipientId) ? new SenderKeySharedEntry( """
recipientId, ).formatted(TABLE_SENDER_KEY_SHARED);
e.deviceId()) : e) try (final var statement = connection.prepareStatement(sql)) {
.collect(Collectors.toSet())); statement.setLong(1, recipientId.id());
statement.setLong(2, toBeMergedRecipientId.id());
statement.executeUpdate();
} }
saveLocked(); } catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
} }
} }
/** void addLegacySenderKeysShared(final Map<DistributionId, Set<SenderKeySharedEntry>> sharedSenderKeys) {
* @param identifier can be either a serialized uuid or a e164 phone number logger.debug("Migrating legacy sender keys shared to database");
*/ long start = System.nanoTime();
private RecipientId resolveRecipient(String identifier) { try (final var connection = database.getConnection()) {
return resolver.resolveRecipient(identifier); connection.setAutoCommit(false);
for (final var entry : sharedSenderKeys.entrySet()) {
markSenderKeysSharedWith(connection, entry.getKey(), entry.getValue());
}
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update shared sender key store", e);
}
logger.debug("Complete sender keys shared migration took {}ms", (System.nanoTime() - start) / 1000000);
} }
private void saveLocked() { private void markSenderKeysSharedWith(
var storage = new Storage(sharedSenderKeys.entrySet().stream().flatMap(pair -> { final Connection connection, final DistributionId distributionId, final Set<SenderKeySharedEntry> newEntries
final var sharedWith = pair.getValue(); ) throws SQLException {
return sharedWith.stream() final var sql = (
.map(entry -> new Storage.SharedSenderKey(entry.recipientId().id(), """
entry.deviceId(), INSERT OR REPLACE INTO %s (recipient_id, device_id, distribution_id, timestamp)
pair.getKey().toString())); VALUES (?, ?, ?, ?)
}).toList()); """
).formatted(TABLE_SENDER_KEY_SHARED);
// Write to memory first to prevent corrupting the file in case of serialization errors try (final var statement = connection.prepareStatement(sql)) {
try (var inMemoryOutput = new ByteArrayOutputStream()) { for (final var entry : newEntries) {
objectMapper.writeValue(inMemoryOutput, storage); statement.setLong(1, entry.recipientId().id());
statement.setInt(2, entry.deviceId());
var input = new ByteArrayInputStream(inMemoryOutput.toByteArray()); statement.setBytes(3, UuidUtil.toByteArray(distributionId.asUuid()));
try (var outputStream = new FileOutputStream(file)) { statement.setLong(4, System.currentTimeMillis());
input.transferTo(outputStream); statement.executeUpdate();
} }
} catch (Exception e) {
logger.error("Error saving shared sender key store file: {}", e.getMessage());
} }
} }
private record Storage(List<SharedSenderKey> sharedSenderKeys) { private SenderKeySharedEntry getSenderKeySharedEntryFromResultSet(ResultSet resultSet) throws SQLException {
final var recipientId = resultSet.getLong("recipient_id");
private record SharedSenderKey(long recipientId, int deviceId, String distributionId) {} final var deviceId = resultSet.getInt("device_id");
return new SenderKeySharedEntry(recipientIdCreator.create(recipientId), deviceId);
} }
private record SenderKeySharedEntry(RecipientId recipientId, int deviceId) {} record SenderKeySharedEntry(RecipientId recipientId, int deviceId) {}
} }

View file

@ -1,15 +1,18 @@
package org.asamk.signal.manager.storage.senderKeys; package org.asamk.signal.manager.storage.senderKeys;
import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.helper.RecipientAddressResolver; import org.asamk.signal.manager.helper.RecipientAddressResolver;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
import org.asamk.signal.manager.storage.recipients.RecipientResolver; import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore; import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import java.io.File;
import java.util.Collection; import java.util.Collection;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -19,13 +22,13 @@ public class SenderKeyStore implements SignalServiceSenderKeyStore {
private final SenderKeySharedStore senderKeySharedStore; private final SenderKeySharedStore senderKeySharedStore;
public SenderKeyStore( public SenderKeyStore(
final File file, final Database database,
final File senderKeysPath,
final RecipientAddressResolver addressResolver, final RecipientAddressResolver addressResolver,
final RecipientResolver resolver final RecipientResolver resolver,
final RecipientIdCreator recipientIdCreator
) { ) {
this.senderKeyRecordStore = new SenderKeyRecordStore(senderKeysPath, resolver); this.senderKeyRecordStore = new SenderKeyRecordStore(database, resolver);
this.senderKeySharedStore = SenderKeySharedStore.load(file, addressResolver, resolver); this.senderKeySharedStore = new SenderKeySharedStore(database, recipientIdCreator, addressResolver, resolver);
} }
@Override @Override
@ -88,4 +91,12 @@ public class SenderKeyStore implements SignalServiceSenderKeyStore {
senderKeySharedStore.mergeRecipients(recipientId, toBeMergedRecipientId); senderKeySharedStore.mergeRecipients(recipientId, toBeMergedRecipientId);
senderKeyRecordStore.mergeRecipients(recipientId, toBeMergedRecipientId); senderKeyRecordStore.mergeRecipients(recipientId, toBeMergedRecipientId);
} }
void addLegacySenderKeys(final Collection<Pair<SenderKeyRecordStore.Key, SenderKeyRecord>> senderKeys) {
senderKeyRecordStore.addLegacySenderKeys(senderKeys);
}
void addLegacySenderKeysShared(final Map<DistributionId, Set<SenderKeySharedStore.SenderKeySharedEntry>> sharedSenderKeys) {
senderKeySharedStore.addLegacySenderKeysShared(sharedSenderKeys);
}
} }