mirror of
https://github.com/AsamK/signal-cli
synced 2025-08-29 18:40:39 +00:00
Move session store to database
This commit is contained in:
parent
65c9a2e185
commit
484daa4c69
4 changed files with 413 additions and 176 deletions
|
@ -7,6 +7,7 @@ import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
|
|||
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
|
||||
import org.asamk.signal.manager.storage.recipients.RecipientStore;
|
||||
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
|
||||
import org.asamk.signal.manager.storage.sessions.SessionStore;
|
||||
import org.asamk.signal.manager.storage.stickers.StickerStore;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -18,7 +19,7 @@ import java.sql.SQLException;
|
|||
public class AccountDatabase extends Database {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
|
||||
private static final long DATABASE_VERSION = 5;
|
||||
private static final long DATABASE_VERSION = 6;
|
||||
|
||||
private AccountDatabase(final HikariDataSource dataSource) {
|
||||
super(logger, DATABASE_VERSION, dataSource);
|
||||
|
@ -36,6 +37,7 @@ public class AccountDatabase extends Database {
|
|||
PreKeyStore.createSql(connection);
|
||||
SignedPreKeyStore.createSql(connection);
|
||||
GroupStore.createSql(connection);
|
||||
SessionStore.createSql(connection);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -143,5 +145,20 @@ public class AccountDatabase extends Database {
|
|||
""");
|
||||
}
|
||||
}
|
||||
if (oldVersion < 6) {
|
||||
logger.debug("Updating database: Creating session tables");
|
||||
try (final var statement = connection.createStatement()) {
|
||||
statement.executeUpdate("""
|
||||
CREATE TABLE session (
|
||||
_id INTEGER PRIMARY KEY,
|
||||
account_id_type INTEGER NOT NULL,
|
||||
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
|
||||
device_id INTEGER NOT NULL,
|
||||
record BLOB NOT NULL,
|
||||
UNIQUE(account_id_type, recipient_id, device_id)
|
||||
);
|
||||
""");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.asamk.signal.manager.storage.recipients.RecipientStore;
|
|||
import org.asamk.signal.manager.storage.recipients.RecipientTrustedResolver;
|
||||
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
|
||||
import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore;
|
||||
import org.asamk.signal.manager.storage.sessions.LegacySessionStore;
|
||||
import org.asamk.signal.manager.storage.sessions.SessionStore;
|
||||
import org.asamk.signal.manager.storage.stickers.LegacyStickerStore;
|
||||
import org.asamk.signal.manager.storage.stickers.StickerStore;
|
||||
|
@ -634,6 +635,11 @@ public class SignalAccount implements Closeable {
|
|||
LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore());
|
||||
migratedLegacyConfig = true;
|
||||
}
|
||||
final var legacySessionsPath = getSessionsPath(dataPath, accountPath);
|
||||
if (legacySessionsPath.exists()) {
|
||||
LegacySessionStore.migrate(legacySessionsPath, getRecipientResolver(), getSessionStore());
|
||||
migratedLegacyConfig = true;
|
||||
}
|
||||
final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore")
|
||||
? jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
|
||||
LegacyJsonSignalProtocolStore.class)
|
||||
|
@ -1067,7 +1073,10 @@ public class SignalAccount implements Closeable {
|
|||
|
||||
public SessionStore getSessionStore() {
|
||||
return getOrCreate(() -> sessionStore,
|
||||
() -> sessionStore = new SessionStore(getSessionsPath(dataPath, accountPath), getRecipientResolver()));
|
||||
() -> sessionStore = new SessionStore(getAccountDatabase(),
|
||||
ServiceIdType.ACI,
|
||||
getRecipientResolver(),
|
||||
getRecipientIdCreator()));
|
||||
}
|
||||
|
||||
public IdentityKeyStore getIdentityKeyStore() {
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
package org.asamk.signal.manager.storage.sessions;
|
||||
|
||||
import org.asamk.signal.manager.api.Pair;
|
||||
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
|
||||
import org.asamk.signal.manager.storage.sessions.SessionStore.Key;
|
||||
import org.asamk.signal.manager.util.IOUtils;
|
||||
import org.signal.libsignal.protocol.state.SessionRecord;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
public class LegacySessionStore {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(LegacySessionStore.class);
|
||||
|
||||
public static void migrate(
|
||||
final File sessionsPath, final RecipientResolver resolver, final SessionStore sessionStore
|
||||
) {
|
||||
final var keys = getKeysLocked(sessionsPath, resolver);
|
||||
final var sessions = keys.stream().map(key -> {
|
||||
final var record = loadSessionLocked(key, sessionsPath);
|
||||
if (record == null) {
|
||||
return null;
|
||||
}
|
||||
return new Pair<>(key, record);
|
||||
}).filter(Objects::nonNull).toList();
|
||||
sessionStore.addLegacySessions(sessions);
|
||||
deleteAllSessions(sessionsPath);
|
||||
}
|
||||
|
||||
private static void deleteAllSessions(File sessionsPath) {
|
||||
final var files = sessionsPath.listFiles();
|
||||
if (files == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (var file : files) {
|
||||
try {
|
||||
Files.delete(file.toPath());
|
||||
} catch (IOException e) {
|
||||
logger.error("Failed to delete session file {}: {}", file, e.getMessage());
|
||||
}
|
||||
}
|
||||
try {
|
||||
Files.delete(sessionsPath.toPath());
|
||||
} catch (IOException e) {
|
||||
logger.error("Failed to delete session directory {}: {}", sessionsPath, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private static Collection<Key> getKeysLocked(File sessionsPath, final RecipientResolver resolver) {
|
||||
final var files = sessionsPath.listFiles();
|
||||
if (files == null) {
|
||||
return List.of();
|
||||
}
|
||||
return parseFileNames(files, resolver);
|
||||
}
|
||||
|
||||
static final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)");
|
||||
|
||||
private static List<Key> parseFileNames(final File[] files, final RecipientResolver resolver) {
|
||||
return Arrays.stream(files)
|
||||
.map(f -> sessionFileNamePattern.matcher(f.getName()))
|
||||
.filter(Matcher::matches)
|
||||
.map(matcher -> {
|
||||
final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
|
||||
if (recipientId == null) {
|
||||
return null;
|
||||
}
|
||||
return new Key(recipientId, Integer.parseInt(matcher.group(2)));
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
}
|
||||
|
||||
private static File getSessionFile(Key key, final File sessionsPath) {
|
||||
try {
|
||||
IOUtils.createPrivateDirectories(sessionsPath);
|
||||
} catch (IOException e) {
|
||||
throw new AssertionError("Failed to create sessions path", e);
|
||||
}
|
||||
return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
|
||||
}
|
||||
|
||||
private static SessionRecord loadSessionLocked(final Key key, final File sessionsPath) {
|
||||
final var file = getSessionFile(key, sessionsPath);
|
||||
if (!file.exists()) {
|
||||
return null;
|
||||
}
|
||||
try (var inputStream = new FileInputStream(file)) {
|
||||
return new SessionRecord(inputStream.readAllBytes());
|
||||
} catch (Exception e) {
|
||||
logger.warn("Failed to load session, resetting session: {}", e.getMessage());
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,8 +1,12 @@
|
|||
package org.asamk.signal.manager.storage.sessions;
|
||||
|
||||
import org.asamk.signal.manager.api.Pair;
|
||||
import org.asamk.signal.manager.storage.Database;
|
||||
import org.asamk.signal.manager.storage.Utils;
|
||||
import org.asamk.signal.manager.storage.recipients.RecipientId;
|
||||
import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
|
||||
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
|
||||
import org.asamk.signal.manager.util.IOUtils;
|
||||
import org.signal.libsignal.protocol.InvalidMessageException;
|
||||
import org.signal.libsignal.protocol.NoSessionException;
|
||||
import org.signal.libsignal.protocol.SignalProtocolAddress;
|
||||
import org.signal.libsignal.protocol.ecc.ECPublicKey;
|
||||
|
@ -11,50 +15,68 @@ import org.signal.libsignal.protocol.state.SessionRecord;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
|
||||
import org.whispersystems.signalservice.api.push.ServiceIdType;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.Arrays;
|
||||
import java.sql.Connection;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SessionStore implements SignalServiceSessionStore {
|
||||
|
||||
private static final String TABLE_SESSION = "session";
|
||||
private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
|
||||
|
||||
private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
|
||||
|
||||
private final File sessionsPath;
|
||||
|
||||
private final Database database;
|
||||
private final int accountIdType;
|
||||
private final RecipientResolver resolver;
|
||||
private final RecipientIdCreator recipientIdCreator;
|
||||
|
||||
public static void createSql(Connection connection) throws SQLException {
|
||||
// When modifying the CREATE statement here, also add a migration in AccountDatabase.java
|
||||
try (final var statement = connection.createStatement()) {
|
||||
statement.executeUpdate("""
|
||||
CREATE TABLE session (
|
||||
_id INTEGER PRIMARY KEY,
|
||||
account_id_type INTEGER NOT NULL,
|
||||
recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
|
||||
device_id INTEGER NOT NULL,
|
||||
record BLOB NOT NULL,
|
||||
UNIQUE(account_id_type, recipient_id, device_id)
|
||||
);
|
||||
""");
|
||||
}
|
||||
}
|
||||
|
||||
public SessionStore(
|
||||
final File sessionsPath, final RecipientResolver resolver
|
||||
final Database database,
|
||||
final ServiceIdType serviceIdType,
|
||||
final RecipientResolver resolver,
|
||||
final RecipientIdCreator recipientIdCreator
|
||||
) {
|
||||
this.sessionsPath = sessionsPath;
|
||||
this.database = database;
|
||||
this.accountIdType = Utils.getAccountIdType(serviceIdType);
|
||||
this.resolver = resolver;
|
||||
this.recipientIdCreator = recipientIdCreator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SessionRecord loadSession(SignalProtocolAddress address) {
|
||||
final var key = getKey(address);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
final var session = loadSessionLocked(key);
|
||||
if (session == null) {
|
||||
return new SessionRecord();
|
||||
}
|
||||
return session;
|
||||
try (final var connection = database.getConnection()) {
|
||||
final var session = loadSession(connection, key);
|
||||
return Objects.requireNonNullElseGet(session, SessionRecord::new);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,8 +84,14 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
|
||||
final var keys = addresses.stream().map(this::getKey).toList();
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
final var sessions = keys.stream().map(this::loadSessionLocked).filter(Objects::nonNull).toList();
|
||||
try (final var connection = database.getConnection()) {
|
||||
final var sessions = new ArrayList<SessionRecord>();
|
||||
for (final var key : keys) {
|
||||
final var sessionRecord = loadSession(connection, key);
|
||||
if (sessionRecord != null) {
|
||||
sessions.add(sessionRecord);
|
||||
}
|
||||
}
|
||||
|
||||
if (sessions.size() != addresses.size()) {
|
||||
String message = "Mismatch! Asked for "
|
||||
|
@ -76,31 +104,44 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
}
|
||||
|
||||
return sessions;
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Integer> getSubDeviceSessions(String name) {
|
||||
final var recipientId = resolveRecipient(name);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
return getKeysLocked(recipientId).stream()
|
||||
final var recipientId = resolver.resolveRecipient(name);
|
||||
// get all sessions for recipient except primary device session
|
||||
.filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId))
|
||||
.map(Key::deviceId)
|
||||
.toList();
|
||||
final var sql = (
|
||||
"""
|
||||
SELECT s.device_id
|
||||
FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var connection = database.getConnection()) {
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, recipientId.id());
|
||||
return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) {
|
||||
final var key = new Key(recipientId, deviceId);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
final var session = loadSessionLocked(key);
|
||||
try (final var connection = database.getConnection()) {
|
||||
final var session = loadSession(connection, key);
|
||||
if (session == null) {
|
||||
return false;
|
||||
}
|
||||
return session.currentRatchetKeyMatches(ratchetKey);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -108,8 +149,10 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
public void storeSession(SignalProtocolAddress address, SessionRecord session) {
|
||||
final var key = getKey(address);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
storeSessionLocked(key, session);
|
||||
try (final var connection = database.getConnection()) {
|
||||
storeSession(connection, key, session);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,9 +160,11 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
public boolean containsSession(SignalProtocolAddress address) {
|
||||
final var key = getKey(address);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
final var session = loadSessionLocked(key);
|
||||
try (final var connection = database.getConnection()) {
|
||||
final var session = loadSession(connection, key);
|
||||
return isActive(session);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -127,23 +172,24 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
public void deleteSession(SignalProtocolAddress address) {
|
||||
final var key = getKey(address);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
deleteSessionLocked(key);
|
||||
try (final var connection = database.getConnection()) {
|
||||
deleteSession(connection, key);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAllSessions(String name) {
|
||||
final var recipientId = resolveRecipient(name);
|
||||
final var recipientId = resolver.resolveRecipient(name);
|
||||
deleteAllSessions(recipientId);
|
||||
}
|
||||
|
||||
public void deleteAllSessions(RecipientId recipientId) {
|
||||
synchronized (cachedSessions) {
|
||||
final var keys = getKeysLocked(recipientId);
|
||||
for (var key : keys) {
|
||||
deleteSessionLocked(key);
|
||||
}
|
||||
try (final var connection = database.getConnection()) {
|
||||
deleteAllSessions(connection, recipientId);
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,186 +197,244 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
public void archiveSession(final SignalProtocolAddress address) {
|
||||
final var key = getKey(address);
|
||||
|
||||
synchronized (cachedSessions) {
|
||||
archiveSessionLocked(key);
|
||||
try (final var connection = database.getConnection()) {
|
||||
connection.setAutoCommit(false);
|
||||
final var session = loadSession(connection, key);
|
||||
if (session != null) {
|
||||
session.archiveCurrentState();
|
||||
storeSession(connection, key, session);
|
||||
connection.commit();
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
|
||||
final var recipientIdToNameMap = addressNames.stream()
|
||||
.collect(Collectors.toMap(this::resolveRecipient, name -> name));
|
||||
synchronized (cachedSessions) {
|
||||
return recipientIdToNameMap.keySet()
|
||||
.collect(Collectors.toMap(resolver::resolveRecipient, name -> name));
|
||||
final var recipientIdsCommaSeparated = recipientIdToNameMap.keySet()
|
||||
.stream()
|
||||
.flatMap(recipientId -> getKeysLocked(recipientId).stream())
|
||||
.filter(key -> isActive(this.loadSessionLocked(key)))
|
||||
.map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId()))
|
||||
.map(recipientId -> String.valueOf(recipientId.id()))
|
||||
.collect(Collectors.joining(","));
|
||||
final var sql = (
|
||||
"""
|
||||
SELECT s.recipient_id, s.device_id, s.record
|
||||
FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id IN (%s)
|
||||
"""
|
||||
).formatted(TABLE_SESSION, recipientIdsCommaSeparated);
|
||||
try (final var connection = database.getConnection()) {
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
return Utils.executeQueryForStream(statement,
|
||||
res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
|
||||
.filter(pair -> isActive(pair.second()))
|
||||
.map(Pair::first)
|
||||
.map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId),
|
||||
key.deviceId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed read from session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void archiveAllSessions() {
|
||||
synchronized (cachedSessions) {
|
||||
final var keys = getKeysLocked();
|
||||
for (var key : keys) {
|
||||
archiveSessionLocked(key);
|
||||
final var sql = (
|
||||
"""
|
||||
SELECT s.recipient_id, s.device_id, s.record
|
||||
FROM %s AS s
|
||||
WHERE s.account_id_type = ?
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var connection = database.getConnection()) {
|
||||
connection.setAutoCommit(false);
|
||||
final List<Pair<Key, SessionRecord>> records;
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
records = Utils.executeQueryForStream(statement,
|
||||
res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
|
||||
}
|
||||
for (final var record : records) {
|
||||
record.second().archiveCurrentState();
|
||||
storeSession(connection, record.first(), record.second());
|
||||
}
|
||||
connection.commit();
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void archiveSessions(final RecipientId recipientId) {
|
||||
synchronized (cachedSessions) {
|
||||
getKeysLocked().stream()
|
||||
.filter(key -> key.recipientId.equals(recipientId))
|
||||
.forEach(this::archiveSessionLocked);
|
||||
final var sql = (
|
||||
"""
|
||||
SELECT s.recipient_id, s.device_id, s.record
|
||||
FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id = ?
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var connection = database.getConnection()) {
|
||||
connection.setAutoCommit(false);
|
||||
final List<Pair<Key, SessionRecord>> records;
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, recipientId.id());
|
||||
records = Utils.executeQueryForStream(statement,
|
||||
res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
|
||||
}
|
||||
for (final var record : records) {
|
||||
record.second().archiveCurrentState();
|
||||
storeSession(connection, record.first(), record.second());
|
||||
}
|
||||
connection.commit();
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
|
||||
try (final var connection = database.getConnection()) {
|
||||
connection.setAutoCommit(false);
|
||||
synchronized (cachedSessions) {
|
||||
final var keys = getKeysLocked(toBeMergedRecipientId);
|
||||
final var otherHasSession = keys.size() > 0;
|
||||
if (!otherHasSession) {
|
||||
return;
|
||||
cachedSessions.clear();
|
||||
}
|
||||
|
||||
final var hasSession = getKeysLocked(recipientId).size() > 0;
|
||||
if (hasSession) {
|
||||
logger.debug("To be merged recipient had sessions, deleting.");
|
||||
deleteAllSessions(toBeMergedRecipientId);
|
||||
} else {
|
||||
logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
|
||||
for (var key : keys) {
|
||||
final var session = loadSessionLocked(key);
|
||||
deleteSessionLocked(key);
|
||||
if (session == null) {
|
||||
continue;
|
||||
}
|
||||
final var newKey = new Key(recipientId, key.deviceId());
|
||||
storeSessionLocked(newKey, session);
|
||||
final var sql = """
|
||||
UPDATE OR IGNORE %s
|
||||
SET recipient_id = ?
|
||||
WHERE account_id_type = ? AND recipient_id = ?
|
||||
""".formatted(TABLE_SESSION);
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setLong(1, recipientId.id());
|
||||
statement.setInt(2, accountIdType);
|
||||
statement.setLong(3, toBeMergedRecipientId.id());
|
||||
final var rows = statement.executeUpdate();
|
||||
if (rows > 0) {
|
||||
logger.debug("Reassigned {} sessions of to be merged recipient.", rows);
|
||||
}
|
||||
}
|
||||
// Delete all conflicting sessions now
|
||||
deleteAllSessions(connection, toBeMergedRecipientId);
|
||||
connection.commit();
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param identifier can be either a serialized uuid or a e164 phone number
|
||||
*/
|
||||
private RecipientId resolveRecipient(String identifier) {
|
||||
return resolver.resolveRecipient(identifier);
|
||||
void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
|
||||
logger.debug("Migrating legacy sessions to database");
|
||||
long start = System.nanoTime();
|
||||
try (final var connection = database.getConnection()) {
|
||||
connection.setAutoCommit(false);
|
||||
for (final var pair : sessions) {
|
||||
storeSession(connection, pair.first(), pair.second());
|
||||
}
|
||||
connection.commit();
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed update session store", e);
|
||||
}
|
||||
logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
|
||||
}
|
||||
|
||||
private Key getKey(final SignalProtocolAddress address) {
|
||||
final var recipientId = resolveRecipient(address.getName());
|
||||
final var recipientId = resolver.resolveRecipient(address.getName());
|
||||
return new Key(recipientId, address.getDeviceId());
|
||||
}
|
||||
|
||||
private List<Key> getKeysLocked(RecipientId recipientId) {
|
||||
final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_"));
|
||||
if (files == null) {
|
||||
return List.of();
|
||||
}
|
||||
return parseFileNames(files);
|
||||
}
|
||||
|
||||
private Collection<Key> getKeysLocked() {
|
||||
final var files = sessionsPath.listFiles();
|
||||
if (files == null) {
|
||||
return List.of();
|
||||
}
|
||||
return parseFileNames(files);
|
||||
}
|
||||
|
||||
final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)");
|
||||
|
||||
private List<Key> parseFileNames(final File[] files) {
|
||||
return Arrays.stream(files)
|
||||
.map(f -> sessionFileNamePattern.matcher(f.getName()))
|
||||
.filter(Matcher::matches)
|
||||
.map(matcher -> {
|
||||
final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
|
||||
if (recipientId == null) {
|
||||
return null;
|
||||
}
|
||||
return new Key(recipientId, Integer.parseInt(matcher.group(2)));
|
||||
})
|
||||
.filter(Objects::nonNull)
|
||||
.toList();
|
||||
}
|
||||
|
||||
private File getSessionFile(Key key) {
|
||||
try {
|
||||
IOUtils.createPrivateDirectories(sessionsPath);
|
||||
} catch (IOException e) {
|
||||
throw new AssertionError("Failed to create sessions path", e);
|
||||
}
|
||||
return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
|
||||
}
|
||||
|
||||
private SessionRecord loadSessionLocked(final Key key) {
|
||||
{
|
||||
private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
|
||||
synchronized (cachedSessions) {
|
||||
final var session = cachedSessions.get(key);
|
||||
if (session != null) {
|
||||
return session;
|
||||
}
|
||||
}
|
||||
|
||||
final var file = getSessionFile(key);
|
||||
if (!file.exists()) {
|
||||
return null;
|
||||
final var sql = (
|
||||
"""
|
||||
SELECT s.record
|
||||
FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, key.recipientId().id());
|
||||
statement.setInt(3, key.deviceId());
|
||||
return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
|
||||
}
|
||||
try (var inputStream = new FileInputStream(file)) {
|
||||
final var session = new SessionRecord(inputStream.readAllBytes());
|
||||
cachedSessions.put(key, session);
|
||||
return session;
|
||||
} catch (Exception e) {
|
||||
}
|
||||
|
||||
private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
|
||||
final var recipientId = resultSet.getLong("recipient_id");
|
||||
final var deviceId = resultSet.getInt("device_id");
|
||||
return new Key(recipientIdCreator.create(recipientId), deviceId);
|
||||
}
|
||||
|
||||
private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
|
||||
try {
|
||||
final var record = resultSet.getBytes("record");
|
||||
return new SessionRecord(record);
|
||||
} catch (InvalidMessageException e) {
|
||||
logger.warn("Failed to load session, resetting session: {}", e.getMessage());
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void storeSessionLocked(final Key key, final SessionRecord session) {
|
||||
private void storeSession(
|
||||
final Connection connection, final Key key, final SessionRecord session
|
||||
) throws SQLException {
|
||||
synchronized (cachedSessions) {
|
||||
cachedSessions.put(key, session);
|
||||
}
|
||||
|
||||
final var file = getSessionFile(key);
|
||||
try {
|
||||
try (var outputStream = new FileOutputStream(file)) {
|
||||
outputStream.write(session.serialize());
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
|
||||
try {
|
||||
Files.delete(file.toPath());
|
||||
try (var outputStream = new FileOutputStream(file)) {
|
||||
outputStream.write(session.serialize());
|
||||
}
|
||||
} catch (IOException e2) {
|
||||
logger.error("Failed to store session file {}: {}", file, e2.getMessage());
|
||||
}
|
||||
final var sql = """
|
||||
INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""".formatted(TABLE_SESSION);
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, key.recipientId().id());
|
||||
statement.setInt(3, key.deviceId());
|
||||
statement.setBytes(4, session.serialize());
|
||||
statement.executeUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
private void archiveSessionLocked(final Key key) {
|
||||
final var session = loadSessionLocked(key);
|
||||
if (session == null) {
|
||||
return;
|
||||
}
|
||||
session.archiveCurrentState();
|
||||
storeSessionLocked(key, session);
|
||||
private void deleteAllSessions(final Connection connection, final RecipientId recipientId) throws SQLException {
|
||||
synchronized (cachedSessions) {
|
||||
cachedSessions.clear();
|
||||
}
|
||||
|
||||
private void deleteSessionLocked(final Key key) {
|
||||
final var sql = (
|
||||
"""
|
||||
DELETE FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id = ?
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, recipientId.id());
|
||||
statement.executeUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
private void deleteSession(Connection connection, final Key key) throws SQLException {
|
||||
synchronized (cachedSessions) {
|
||||
cachedSessions.remove(key);
|
||||
|
||||
final var file = getSessionFile(key);
|
||||
if (!file.exists()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
Files.delete(file.toPath());
|
||||
} catch (IOException e) {
|
||||
logger.error("Failed to delete session file {}: {}", file, e.getMessage());
|
||||
|
||||
final var sql = (
|
||||
"""
|
||||
DELETE FROM %s AS s
|
||||
WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
|
||||
"""
|
||||
).formatted(TABLE_SESSION);
|
||||
try (final var statement = connection.prepareStatement(sql)) {
|
||||
statement.setInt(1, accountIdType);
|
||||
statement.setLong(2, key.recipientId().id());
|
||||
statement.setInt(3, key.deviceId());
|
||||
statement.executeUpdate();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -340,5 +444,5 @@ public class SessionStore implements SignalServiceSessionStore {
|
|||
&& record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
|
||||
}
|
||||
|
||||
private record Key(RecipientId recipientId, int deviceId) {}
|
||||
record Key(RecipientId recipientId, int deviceId) {}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue