Move pre key stores to database

This commit is contained in:
AsamK 2022-06-08 12:25:42 +02:00
parent 9a698929f4
commit 7da2e1b262
8 changed files with 500 additions and 145 deletions

View file

@ -49,6 +49,10 @@ public class PreKeyHelper {
if (identityKeyPair == null) { if (identityKeyPair == null) {
return; return;
} }
final var accountId = account.getAccountId(serviceIdType);
if (accountId == null) {
return;
}
final var oneTimePreKeys = generatePreKeys(serviceIdType); final var oneTimePreKeys = generatePreKeys(serviceIdType);
final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);

View file

@ -2,6 +2,8 @@ package org.asamk.signal.manager.storage;
import com.zaxxer.hikari.HikariDataSource; import com.zaxxer.hikari.HikariDataSource;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientStore;
import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore; import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
import org.asamk.signal.manager.storage.stickers.StickerStore; import org.asamk.signal.manager.storage.stickers.StickerStore;
@ -15,7 +17,7 @@ import java.sql.SQLException;
public class AccountDatabase extends Database { public class AccountDatabase extends Database {
private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
private static final long DATABASE_VERSION = 3; private static final long DATABASE_VERSION = 4;
private AccountDatabase(final HikariDataSource dataSource) { private AccountDatabase(final HikariDataSource dataSource) {
super(logger, DATABASE_VERSION, dataSource); super(logger, DATABASE_VERSION, dataSource);
@ -30,6 +32,8 @@ public class AccountDatabase extends Database {
RecipientStore.createSql(connection); RecipientStore.createSql(connection);
MessageSendLogStore.createSql(connection); MessageSendLogStore.createSql(connection);
StickerStore.createSql(connection); StickerStore.createSql(connection);
PreKeyStore.createSql(connection);
SignedPreKeyStore.createSql(connection);
} }
@Override @Override
@ -80,5 +84,30 @@ public class AccountDatabase extends Database {
"""); """);
} }
} }
if (oldVersion < 4) {
logger.debug("Updating database: Creating pre key tables");
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE signed_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
public_key BLOB NOT NULL,
private_key BLOB NOT NULL,
signature BLOB NOT NULL,
timestamp INTEGER DEFAULT 0,
UNIQUE(account_id_type, key_id)
);
CREATE TABLE pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
public_key BLOB NOT NULL,
private_key BLOB NOT NULL,
UNIQUE(account_id_type, key_id)
);
""");
}
}
} }
} }

View file

@ -18,6 +18,8 @@ import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore; import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore;
import org.asamk.signal.manager.storage.identities.TrustNewIdentity; import org.asamk.signal.manager.storage.identities.TrustNewIdentity;
import org.asamk.signal.manager.storage.messageCache.MessageCache; import org.asamk.signal.manager.storage.messageCache.MessageCache;
import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore; import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.profiles.LegacyProfileStore; import org.asamk.signal.manager.storage.profiles.LegacyProfileStore;
@ -628,6 +630,26 @@ public class SignalAccount implements Closeable {
migratedLegacyConfig = true; migratedLegacyConfig = true;
} }
} }
final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath);
if (legacyAciPreKeysPath.exists()) {
LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore());
migratedLegacyConfig = true;
}
final var legacyPniPreKeysPath = getPniPreKeysPath(dataPath, accountPath);
if (legacyPniPreKeysPath.exists()) {
LegacyPreKeyStore.migrate(legacyPniPreKeysPath, getPniPreKeyStore());
migratedLegacyConfig = true;
}
final var legacyAciSignedPreKeysPath = getAciSignedPreKeysPath(dataPath, accountPath);
if (legacyAciSignedPreKeysPath.exists()) {
LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, getAciSignedPreKeyStore());
migratedLegacyConfig = true;
}
final var legacyPniSignedPreKeysPath = getPniSignedPreKeysPath(dataPath, accountPath);
if (legacyPniSignedPreKeysPath.exists()) {
LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore());
migratedLegacyConfig = true;
}
final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore") final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore")
? jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"), ? jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
LegacyJsonSignalProtocolStore.class) LegacyJsonSignalProtocolStore.class)
@ -1012,7 +1034,7 @@ public class SignalAccount implements Closeable {
@Override @Override
public SignalServiceAccountDataStore get(final ServiceId accountIdentifier) { public SignalServiceAccountDataStore get(final ServiceId accountIdentifier) {
if (accountIdentifier.equals(aci)) { if (accountIdentifier.equals(aci)) {
return getSignalServiceAccountDataStore(); return getAciSignalServiceAccountDataStore();
} else if (accountIdentifier.equals(pni)) { } else if (accountIdentifier.equals(pni)) {
throw new AssertionError("PNI not to be used yet!"); throw new AssertionError("PNI not to be used yet!");
} else { } else {
@ -1022,7 +1044,7 @@ public class SignalAccount implements Closeable {
@Override @Override
public SignalServiceAccountDataStore aci() { public SignalServiceAccountDataStore aci() {
return getSignalServiceAccountDataStore(); return getAciSignalServiceAccountDataStore();
} }
@Override @Override
@ -1037,7 +1059,7 @@ public class SignalAccount implements Closeable {
}; };
} }
public SignalServiceAccountDataStore getSignalServiceAccountDataStore() { private SignalServiceAccountDataStore getAciSignalServiceAccountDataStore() {
return getOrCreate(() -> signalProtocolStore, return getOrCreate(() -> signalProtocolStore,
() -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(), () -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getAciSignedPreKeyStore(), getAciSignedPreKeyStore(),
@ -1049,22 +1071,22 @@ public class SignalAccount implements Closeable {
private PreKeyStore getAciPreKeyStore() { private PreKeyStore getAciPreKeyStore() {
return getOrCreate(() -> aciPreKeyStore, return getOrCreate(() -> aciPreKeyStore,
() -> aciPreKeyStore = new PreKeyStore(getAciPreKeysPath(dataPath, accountPath))); () -> aciPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
} }
private SignedPreKeyStore getAciSignedPreKeyStore() { private SignedPreKeyStore getAciSignedPreKeyStore() {
return getOrCreate(() -> aciSignedPreKeyStore, return getOrCreate(() -> aciSignedPreKeyStore,
() -> aciSignedPreKeyStore = new SignedPreKeyStore(getAciSignedPreKeysPath(dataPath, accountPath))); () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
} }
private PreKeyStore getPniPreKeyStore() { private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore, return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getPniPreKeysPath(dataPath, accountPath))); () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
} }
private SignedPreKeyStore getPniSignedPreKeyStore() { private SignedPreKeyStore getPniSignedPreKeyStore() {
return getOrCreate(() -> pniSignedPreKeyStore, return getOrCreate(() -> pniSignedPreKeyStore,
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getPniSignedPreKeysPath(dataPath, accountPath))); () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
} }
public SessionStore getSessionStore() { public SessionStore getSessionStore() {
@ -1221,6 +1243,10 @@ public class SignalAccount implements Closeable {
save(); save();
} }
public ServiceId getAccountId(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aci : pni;
}
public ACI getAci() { public ACI getAci() {
return aci; return aci;
} }

View file

@ -12,6 +12,7 @@ import com.fasterxml.jackson.databind.SerializationFeature;
import org.asamk.signal.manager.storage.recipients.RecipientAddress; import org.asamk.signal.manager.storage.recipients.RecipientAddress;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.util.UuidUtil; import org.whispersystems.signalservice.api.util.UuidUtil;
import java.io.InvalidObjectException; import java.io.InvalidObjectException;
@ -62,6 +63,13 @@ public class Utils {
} }
} }
public static int getAccountIdType(ServiceIdType serviceIdType) {
return switch (serviceIdType) {
case ACI -> 0;
case PNI -> 1;
};
}
public static <T> T executeQuerySingleRow( public static <T> T executeQuerySingleRow(
PreparedStatement statement, ResultSetMapper<T> mapper PreparedStatement statement, ResultSetMapper<T> mapper
) throws SQLException { ) throws SQLException {

View file

@ -0,0 +1,61 @@
package org.asamk.signal.manager.storage.prekeys;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.regex.Pattern;
public class LegacyPreKeyStore {
private final static Logger logger = LoggerFactory.getLogger(LegacyPreKeyStore.class);
static final Pattern preKeyFileNamePattern = Pattern.compile("(\\d+)");
public static void migrate(File preKeysPath, PreKeyStore preKeyStore) {
final var files = preKeysPath.listFiles();
if (files == null) {
return;
}
final var preKeyRecords = Arrays.stream(files)
.filter(f -> preKeyFileNamePattern.matcher(f.getName()).matches())
.map(LegacyPreKeyStore::loadPreKeyRecord)
.toList();
preKeyStore.addLegacyPreKeys(preKeyRecords);
removeAllPreKeys(preKeysPath);
}
private static void removeAllPreKeys(File preKeysPath) {
final var files = preKeysPath.listFiles();
if (files == null) {
return;
}
for (var file : files) {
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete pre key file {}: {}", file, e.getMessage());
}
}
try {
Files.delete(preKeysPath.toPath());
} catch (IOException e) {
logger.error("Failed to delete pre key directory {}: {}", preKeysPath, e.getMessage());
}
}
private static PreKeyRecord loadPreKeyRecord(final File file) {
try (var inputStream = new FileInputStream(file)) {
return new PreKeyRecord(inputStream.readAllBytes());
} catch (IOException | InvalidMessageException e) {
logger.error("Failed to load pre key: {}", e.getMessage());
throw new AssertionError(e);
}
}
}

View file

@ -0,0 +1,61 @@
package org.asamk.signal.manager.storage.prekeys;
import org.signal.libsignal.protocol.InvalidMessageException;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.regex.Pattern;
public class LegacySignedPreKeyStore {
private final static Logger logger = LoggerFactory.getLogger(LegacySignedPreKeyStore.class);
static final Pattern signedPreKeyFileNamePattern = Pattern.compile("(\\d+)");
public static void migrate(File signedPreKeysPath, SignedPreKeyStore signedPreKeyStore) {
final var files = signedPreKeysPath.listFiles();
if (files == null) {
return;
}
final var signedPreKeyRecords = Arrays.stream(files)
.filter(f -> signedPreKeyFileNamePattern.matcher(f.getName()).matches())
.map(LegacySignedPreKeyStore::loadSignedPreKeyRecord)
.toList();
signedPreKeyStore.addLegacySignedPreKeys(signedPreKeyRecords);
removeAllSignedPreKeys(signedPreKeysPath);
}
private static void removeAllSignedPreKeys(File signedPreKeysPath) {
final var files = signedPreKeysPath.listFiles();
if (files == null) {
return;
}
for (var file : files) {
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete signed pre key file {}: {}", file, e.getMessage());
}
}
try {
Files.delete(signedPreKeysPath.toPath());
} catch (IOException e) {
logger.error("Failed to delete signed pre key directory {}: {}", signedPreKeysPath, e.getMessage());
}
}
private static SignedPreKeyRecord loadSignedPreKeyRecord(final File file) {
try (var inputStream = new FileInputStream(file)) {
return new SignedPreKeyRecord(inputStream.readAllBytes());
} catch (IOException | InvalidMessageException e) {
logger.error("Failed to load signed pre key: {}", e.getMessage());
throw new AssertionError(e);
}
}
}

View file

@ -1,105 +1,184 @@
package org.asamk.signal.manager.storage.prekeys; package org.asamk.signal.manager.storage.prekeys;
import org.asamk.signal.manager.util.IOUtils; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.InvalidKeyIdException; import org.signal.libsignal.protocol.InvalidKeyIdException;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import java.io.File; import java.sql.Connection;
import java.io.FileInputStream; import java.sql.ResultSet;
import java.io.FileOutputStream; import java.sql.SQLException;
import java.io.IOException; import java.util.Collection;
import java.nio.file.Files;
public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeyStore { public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeyStore {
private static final String TABLE_PRE_KEY = "pre_key";
private final static Logger logger = LoggerFactory.getLogger(PreKeyStore.class); private final static Logger logger = LoggerFactory.getLogger(PreKeyStore.class);
private final File preKeysPath; private final Database database;
private final int accountIdType;
public PreKeyStore(final File preKeysPath) { public static void createSql(Connection connection) throws SQLException {
this.preKeysPath = preKeysPath; // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE pre_key (
_id INTEGER PRIMARY KEY,
account_id_type INTEGER NOT NULL,
key_id INTEGER NOT NULL,
public_key BLOB NOT NULL,
private_key BLOB NOT NULL,
UNIQUE(account_id_type, key_id)
);
""");
}
}
public PreKeyStore(final Database database, final ServiceIdType serviceIdType) {
this.database = database;
this.accountIdType = Utils.getAccountIdType(serviceIdType);
} }
@Override @Override
public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException { public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
final var file = getPreKeyFile(preKeyId); final var preKey = getPreKey(preKeyId);
if (preKey == null) {
if (!file.exists()) { throw new InvalidKeyIdException("No such signed pre key record!");
throw new InvalidKeyIdException("No such pre key record!");
}
try (var inputStream = new FileInputStream(file)) {
return new PreKeyRecord(inputStream.readAllBytes());
} catch (IOException | InvalidMessageException e) {
logger.error("Failed to load pre key: {}", e.getMessage());
throw new AssertionError(e);
} }
return preKey;
} }
@Override @Override
public void storePreKey(int preKeyId, PreKeyRecord record) { public void storePreKey(int preKeyId, PreKeyRecord record) {
final var file = getPreKeyFile(preKeyId); final var sql = (
try { """
try (var outputStream = new FileOutputStream(file)) { INSERT INTO %s (account_id_type, key_id, public_key, private_key)
outputStream.write(record.serialize()); VALUES (?, ?, ?, ?)
} """
} catch (IOException e) { ).formatted(TABLE_PRE_KEY);
logger.warn("Failed to store pre key, trying to delete file and retry: {}", e.getMessage()); try (final var connection = database.getConnection()) {
try { try (final var statement = connection.prepareStatement(sql)) {
Files.delete(file.toPath()); statement.setInt(1, accountIdType);
try (var outputStream = new FileOutputStream(file)) { statement.setLong(2, preKeyId);
outputStream.write(record.serialize()); final var keyPair = record.getKeyPair();
} statement.setBytes(3, keyPair.getPublicKey().serialize());
} catch (IOException e2) { statement.setBytes(4, keyPair.getPrivateKey().serialize());
logger.error("Failed to store pre key file {}: {}", file, e2.getMessage()); statement.executeUpdate();
} catch (InvalidKeyException ignored) {
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update pre_key store", e);
} }
} }
@Override @Override
public boolean containsPreKey(int preKeyId) { public boolean containsPreKey(int preKeyId) {
final var file = getPreKeyFile(preKeyId); return getPreKey(preKeyId) != null;
return file.exists();
} }
@Override @Override
public void removePreKey(int preKeyId) { public void removePreKey(int preKeyId) {
final var file = getPreKeyFile(preKeyId); final var sql = (
"""
if (!file.exists()) { DELETE FROM %s AS p
return; WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setLong(2, preKeyId);
statement.executeUpdate();
} }
try { } catch (SQLException e) {
Files.delete(file.toPath()); throw new RuntimeException("Failed update pre_key store", e);
} catch (IOException e) {
logger.error("Failed to delete pre key file {}: {}", file, e.getMessage());
} }
} }
public void removeAllPreKeys() { public void removeAllPreKeys() {
final var files = preKeysPath.listFiles(); final var sql = (
if (files == null) { """
return; DELETE FROM %s AS p
} WHERE p.account_id_type = ?
"""
for (var file : files) { ).formatted(TABLE_PRE_KEY);
try { try (final var connection = database.getConnection()) {
Files.delete(file.toPath()); try (final var statement = connection.prepareStatement(sql)) {
} catch (IOException e) { statement.setInt(1, accountIdType);
logger.error("Failed to delete pre key file {}: {}", file, e.getMessage()); statement.executeUpdate();
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update pre_key store", e);
} }
} }
private File getPreKeyFile(int preKeyId) { void addLegacyPreKeys(final Collection<PreKeyRecord> preKeys) {
try { logger.debug("Migrating legacy preKeys to database");
IOUtils.createPrivateDirectories(preKeysPath); long start = System.nanoTime();
} catch (IOException e) { final var sql = (
throw new AssertionError("Failed to create pre keys path", e); """
INSERT INTO %s (account_id_type, key_id, public_key, private_key)
VALUES (?, ?, ?, ?)
"""
).formatted(TABLE_PRE_KEY);
try (final var connection = database.getConnection()) {
connection.setAutoCommit(false);
final var deleteSql = "DELETE FROM %s AS p WHERE p.account_id_type = ?".formatted(TABLE_PRE_KEY);
try (final var statement = connection.prepareStatement(deleteSql)) {
statement.setInt(1, accountIdType);
statement.executeUpdate();
}
try (final var statement = connection.prepareStatement(sql)) {
for (final var record : preKeys) {
statement.setInt(1, accountIdType);
statement.setLong(2, record.getId());
final var keyPair = record.getKeyPair();
statement.setBytes(3, keyPair.getPublicKey().serialize());
statement.setBytes(4, keyPair.getPrivateKey().serialize());
statement.executeUpdate();
}
} catch (InvalidKeyException ignored) {
}
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update preKey store", e);
}
logger.debug("Complete preKeys migration took {}ms", (System.nanoTime() - start) / 1000000);
}
private PreKeyRecord getPreKey(int preKeyId) {
final var sql = (
"""
SELECT p.key_id, p.public_key, p.private_key
FROM %s p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setLong(2, preKeyId);
return Utils.executeQueryForOptional(statement, this::getPreKeyRecordFromResultSet).orElse(null);
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from pre_key store", e);
}
}
private PreKeyRecord getPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
try {
final var keyId = resultSet.getInt("key_id");
final var publicKey = Curve.decodePoint(resultSet.getBytes("public_key"), 0);
final var privateKey = Curve.decodePrivatePoint(resultSet.getBytes("private_key"));
return new PreKeyRecord(keyId, new ECKeyPair(publicKey, privateKey));
} catch (InvalidKeyException e) {
return null;
} }
return new File(preKeysPath, String.valueOf(preKeyId));
} }
} }

View file

@ -1,126 +1,213 @@
package org.asamk.signal.manager.storage.prekeys; package org.asamk.signal.manager.storage.prekeys;
import org.asamk.signal.manager.util.IOUtils; import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.InvalidKeyIdException; import org.signal.libsignal.protocol.InvalidKeyIdException;
import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import java.io.File; import java.sql.Connection;
import java.io.FileInputStream; import java.sql.ResultSet;
import java.io.FileOutputStream; import java.sql.SQLException;
import java.io.IOException; import java.util.Collection;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.regex.Pattern; import java.util.Objects;
public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.SignedPreKeyStore { public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.SignedPreKeyStore {
private static final String TABLE_SIGNED_PRE_KEY = "signed_pre_key";
private final static Logger logger = LoggerFactory.getLogger(SignedPreKeyStore.class); private final static Logger logger = LoggerFactory.getLogger(SignedPreKeyStore.class);
private final File signedPreKeysPath; private final Database database;
private final int accountIdType;
public SignedPreKeyStore(final File signedPreKeysPath) { public static void createSql(Connection connection) throws SQLException {
this.signedPreKeysPath = signedPreKeysPath; // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
try (final var statement = connection.createStatement()) {
statement.executeUpdate("""
CREATE TABLE signed_pre_key (
_id INTEGER PRIMARY KEY,
account_id_type BLOB NOT NULL,
key_id INTEGER NOT NULL,
public_key BLOB NOT NULL,
private_key BLOB NOT NULL,
signature BLOB NOT NULL,
timestamp INTEGER DEFAULT 0,
UNIQUE(account_id_type, key_id)
);
""");
}
}
public SignedPreKeyStore(final Database database, final ServiceIdType serviceIdType) {
this.database = database;
this.accountIdType = Utils.getAccountIdType(serviceIdType);
} }
@Override @Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException { public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
final var file = getSignedPreKeyFile(signedPreKeyId); final SignedPreKeyRecord signedPreKeyRecord = getSignedPreKey(signedPreKeyId);
if (signedPreKeyRecord == null) {
if (!file.exists()) {
throw new InvalidKeyIdException("No such signed pre key record!"); throw new InvalidKeyIdException("No such signed pre key record!");
} }
return loadSignedPreKeyRecord(file); return signedPreKeyRecord;
} }
final Pattern signedPreKeyFileNamePattern = Pattern.compile("(\\d+)");
@Override @Override
public List<SignedPreKeyRecord> loadSignedPreKeys() { public List<SignedPreKeyRecord> loadSignedPreKeys() {
final var files = signedPreKeysPath.listFiles(); final var sql = (
if (files == null) { """
return List.of(); SELECT p.key_id, p.public_key, p.private_key, p.signature, p.timestamp
} FROM %s p
return Arrays.stream(files) WHERE p.account_id_type = ?
.filter(f -> signedPreKeyFileNamePattern.matcher(f.getName()).matches()) """
.map(this::loadSignedPreKeyRecord) ).formatted(TABLE_SIGNED_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
return Utils.executeQueryForStream(statement, this::getSignedPreKeyRecordFromResultSet)
.filter(Objects::nonNull)
.toList(); .toList();
} }
} catch (SQLException e) {
throw new RuntimeException("Failed read from signed_pre_key store", e);
}
}
@Override @Override
public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) { public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) {
final var file = getSignedPreKeyFile(signedPreKeyId); final var sql = (
try { """
try (var outputStream = new FileOutputStream(file)) { INSERT INTO %s (account_id_type, key_id, public_key, private_key, signature, timestamp)
outputStream.write(record.serialize()); VALUES (?, ?, ?, ?, ?, ?)
} """
} catch (IOException e) { ).formatted(TABLE_SIGNED_PRE_KEY);
logger.warn("Failed to store signed pre key, trying to delete file and retry: {}", e.getMessage()); try (final var connection = database.getConnection()) {
try { try (final var statement = connection.prepareStatement(sql)) {
Files.delete(file.toPath()); statement.setInt(1, accountIdType);
try (var outputStream = new FileOutputStream(file)) { statement.setLong(2, signedPreKeyId);
outputStream.write(record.serialize()); final var keyPair = record.getKeyPair();
} statement.setBytes(3, keyPair.getPublicKey().serialize());
} catch (IOException e2) { statement.setBytes(4, keyPair.getPrivateKey().serialize());
logger.error("Failed to store signed pre key file {}: {}", file, e2.getMessage()); statement.setBytes(5, record.getSignature());
statement.setLong(6, record.getTimestamp());
statement.executeUpdate();
} }
} catch (SQLException e) {
throw new RuntimeException("Failed update signed_pre_key store", e);
} }
} }
@Override @Override
public boolean containsSignedPreKey(int signedPreKeyId) { public boolean containsSignedPreKey(int signedPreKeyId) {
final var file = getSignedPreKeyFile(signedPreKeyId); return getSignedPreKey(signedPreKeyId) != null;
return file.exists();
} }
@Override @Override
public void removeSignedPreKey(int signedPreKeyId) { public void removeSignedPreKey(int signedPreKeyId) {
final var file = getSignedPreKeyFile(signedPreKeyId); final var sql = (
"""
if (!file.exists()) { DELETE FROM %s AS p
return; WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_SIGNED_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setLong(2, signedPreKeyId);
statement.executeUpdate();
} }
try { } catch (SQLException e) {
Files.delete(file.toPath()); throw new RuntimeException("Failed update signed_pre_key store", e);
} catch (IOException e) {
logger.error("Failed to delete signed pre key file {}: {}", file, e.getMessage());
} }
} }
public void removeAllSignedPreKeys() { public void removeAllSignedPreKeys() {
final var files = signedPreKeysPath.listFiles(); final var sql = (
if (files == null) { """
return; DELETE FROM %s AS p
WHERE p.account_id_type = ?
"""
).formatted(TABLE_SIGNED_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.executeUpdate();
}
} catch (SQLException e) {
throw new RuntimeException("Failed update signed_pre_key store", e);
}
} }
for (var file : files) { void addLegacySignedPreKeys(final Collection<SignedPreKeyRecord> signedPreKeys) {
logger.debug("Migrating legacy signedPreKeys to database");
long start = System.nanoTime();
final var sql = (
"""
INSERT INTO %s (account_id_type, key_id, public_key, private_key, signature, timestamp)
VALUES (?, ?, ?, ?, ?, ?)
"""
).formatted(TABLE_SIGNED_PRE_KEY);
try (final var connection = database.getConnection()) {
connection.setAutoCommit(false);
final var deleteSql = "DELETE FROM %s AS p WHERE p.account_id_type = ?".formatted(TABLE_SIGNED_PRE_KEY);
try (final var statement = connection.prepareStatement(deleteSql)) {
statement.setInt(1, accountIdType);
statement.executeUpdate();
}
try (final var statement = connection.prepareStatement(sql)) {
for (final var record : signedPreKeys) {
statement.setInt(1, accountIdType);
statement.setLong(2, record.getId());
final var keyPair = record.getKeyPair();
statement.setBytes(3, keyPair.getPublicKey().serialize());
statement.setBytes(4, keyPair.getPrivateKey().serialize());
statement.setBytes(5, record.getSignature());
statement.setLong(6, record.getTimestamp());
statement.executeUpdate();
}
}
connection.commit();
} catch (SQLException e) {
throw new RuntimeException("Failed update signedPreKey store", e);
}
logger.debug("Complete signedPreKeys migration took {}ms", (System.nanoTime() - start) / 1000000);
}
private SignedPreKeyRecord getSignedPreKey(int signedPreKeyId) {
final var sql = (
"""
SELECT p.key_id, p.public_key, p.private_key, p.signature, p.timestamp
FROM %s p
WHERE p.account_id_type = ? AND p.key_id = ?
"""
).formatted(TABLE_SIGNED_PRE_KEY);
try (final var connection = database.getConnection()) {
try (final var statement = connection.prepareStatement(sql)) {
statement.setInt(1, accountIdType);
statement.setLong(2, signedPreKeyId);
return Utils.executeQueryForOptional(statement, this::getSignedPreKeyRecordFromResultSet).orElse(null);
}
} catch (SQLException e) {
throw new RuntimeException("Failed read from signed_pre_key store", e);
}
}
private SignedPreKeyRecord getSignedPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
try { try {
Files.delete(file.toPath()); final var keyId = resultSet.getInt("key_id");
} catch (IOException e) { final var publicKey = Curve.decodePoint(resultSet.getBytes("public_key"), 0);
logger.error("Failed to delete signed pre key file {}: {}", file, e.getMessage()); final var privateKey = Curve.decodePrivatePoint(resultSet.getBytes("private_key"));
} final var signature = resultSet.getBytes("signature");
} final var timestamp = resultSet.getLong("timestamp");
} return new SignedPreKeyRecord(keyId, timestamp, new ECKeyPair(publicKey, privateKey), signature);
} catch (InvalidKeyException e) {
private File getSignedPreKeyFile(int signedPreKeyId) { return null;
try {
IOUtils.createPrivateDirectories(signedPreKeysPath);
} catch (IOException e) {
throw new AssertionError("Failed to create signed pre keys path", e);
}
return new File(signedPreKeysPath, String.valueOf(signedPreKeyId));
}
private SignedPreKeyRecord loadSignedPreKeyRecord(final File file) {
try (var inputStream = new FileInputStream(file)) {
return new SignedPreKeyRecord(inputStream.readAllBytes());
} catch (IOException | InvalidMessageException e) {
logger.error("Failed to load signed pre key: {}", e.getMessage());
throw new AssertionError(e);
} }
} }
} }