Refactor signed pre key store

This commit is contained in:
AsamK 2021-04-17 16:06:35 +02:00
parent ccc380f575
commit afb22deada
7 changed files with 204 additions and 132 deletions

View file

@ -16,6 +16,7 @@ import org.asamk.signal.manager.storage.groups.GroupInfoV1;
import org.asamk.signal.manager.storage.groups.JsonGroupStore; import org.asamk.signal.manager.storage.groups.JsonGroupStore;
import org.asamk.signal.manager.storage.messageCache.MessageCache; import org.asamk.signal.manager.storage.messageCache.MessageCache;
import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
import org.asamk.signal.manager.storage.profiles.ProfileStore; import org.asamk.signal.manager.storage.profiles.ProfileStore;
import org.asamk.signal.manager.storage.protocol.JsonSignalProtocolStore; import org.asamk.signal.manager.storage.protocol.JsonSignalProtocolStore;
import org.asamk.signal.manager.storage.protocol.SignalServiceAddressResolver; import org.asamk.signal.manager.storage.protocol.SignalServiceAddressResolver;
@ -82,6 +83,7 @@ public class SignalAccount implements Closeable {
private JsonSignalProtocolStore signalProtocolStore; private JsonSignalProtocolStore signalProtocolStore;
private PreKeyStore preKeyStore; private PreKeyStore preKeyStore;
private SignedPreKeyStore signedPreKeyStore;
private SessionStore sessionStore; private SessionStore sessionStore;
private JsonGroupStore groupStore; private JsonGroupStore groupStore;
private JsonContactsStore contactStore; private JsonContactsStore contactStore;
@ -136,11 +138,13 @@ public class SignalAccount implements Closeable {
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username), account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients); account::mergeRecipients);
account.preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username)); account.preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username));
account.signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, username));
account.sessionStore = new SessionStore(getSessionsPath(dataPath, username), account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
account.recipientStore::resolveRecipient); account.recipientStore::resolveRecipient);
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, account.signalProtocolStore = new JsonSignalProtocolStore(identityKey,
registrationId, registrationId,
account.preKeyStore, account.preKeyStore,
account.signedPreKeyStore,
account.sessionStore); account.sessionStore);
account.profileStore = new ProfileStore(); account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore(); account.stickerStore = new StickerStore();
@ -183,11 +187,13 @@ public class SignalAccount implements Closeable {
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username), account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients); account::mergeRecipients);
account.preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username)); account.preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username));
account.signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, username));
account.sessionStore = new SessionStore(getSessionsPath(dataPath, username), account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
account.recipientStore::resolveRecipient); account.recipientStore::resolveRecipient);
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, account.signalProtocolStore = new JsonSignalProtocolStore(identityKey,
registrationId, registrationId,
account.preKeyStore, account.preKeyStore,
account.signedPreKeyStore,
account.sessionStore); account.sessionStore);
account.profileStore = new ProfileStore(); account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore(); account.stickerStore = new StickerStore();
@ -251,6 +257,10 @@ public class SignalAccount implements Closeable {
return new File(getUserPath(dataPath, username), "pre-keys"); return new File(getUserPath(dataPath, username), "pre-keys");
} }
private static File getSignedPreKeysPath(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "signed-pre-keys");
}
private static File getSessionsPath(File dataPath, String username) { private static File getSessionsPath(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "sessions"); return new File(getUserPath(dataPath, username), "sessions");
} }
@ -330,6 +340,7 @@ public class SignalAccount implements Closeable {
signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"), signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
JsonSignalProtocolStore.class); JsonSignalProtocolStore.class);
preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username)); preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, username));
if (signalProtocolStore.getLegacyPreKeyStore() != null) { if (signalProtocolStore.getLegacyPreKeyStore() != null) {
logger.debug("Migrating legacy pre key store."); logger.debug("Migrating legacy pre key store.");
@ -342,6 +353,20 @@ public class SignalAccount implements Closeable {
} }
} }
signalProtocolStore.setPreKeyStore(preKeyStore); signalProtocolStore.setPreKeyStore(preKeyStore);
signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, username));
if (signalProtocolStore.getLegacySignedPreKeyStore() != null) {
logger.debug("Migrating legacy signed pre key store.");
for (var entry : signalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
try {
signedPreKeyStore.storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue()));
} catch (IOException e) {
logger.warn("Failed to migrate signed pre key, ignoring", e);
}
}
}
signalProtocolStore.setSignedPreKeyStore(signedPreKeyStore);
sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient); sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient);
if (signalProtocolStore.getLegacySessionStore() != null) { if (signalProtocolStore.getLegacySessionStore() != null) {
logger.debug("Migrating legacy session store."); logger.debug("Migrating legacy session store.");
@ -355,6 +380,7 @@ public class SignalAccount implements Closeable {
} }
} }
signalProtocolStore.setSessionStore(sessionStore); signalProtocolStore.setSessionStore(sessionStore);
registered = Utils.getNotNullNode(rootNode, "registered").asBoolean(); registered = Utils.getNotNullNode(rootNode, "registered").asBoolean();
var groupStoreNode = rootNode.get("groupStore"); var groupStoreNode = rootNode.get("groupStore");
if (groupStoreNode != null) { if (groupStoreNode != null) {

View file

@ -0,0 +1,111 @@
package org.asamk.signal.manager.storage.prekeys;
import org.asamk.signal.manager.util.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.InvalidKeyIdException;
import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
public class SignedPreKeyStore implements org.whispersystems.libsignal.state.SignedPreKeyStore {
private final static Logger logger = LoggerFactory.getLogger(SignedPreKeyStore.class);
private final File signedPreKeysPath;
public SignedPreKeyStore(final File signedPreKeysPath) {
this.signedPreKeysPath = signedPreKeysPath;
}
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
final var file = getSignedPreKeyFile(signedPreKeyId);
if (!file.exists()) {
throw new InvalidKeyIdException("No such signed pre key record!");
}
return loadSignedPreKeyRecord(file);
}
final Pattern signedPreKeyFileNamePattern = Pattern.compile("([0-9]+)");
@Override
public List<SignedPreKeyRecord> loadSignedPreKeys() {
final var files = signedPreKeysPath.listFiles();
if (files == null) {
return List.of();
}
return Arrays.stream(files)
.filter(f -> signedPreKeyFileNamePattern.matcher(f.getName()).matches())
.map(this::loadSignedPreKeyRecord)
.collect(Collectors.toList());
}
@Override
public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) {
final var file = getSignedPreKeyFile(signedPreKeyId);
try {
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(record.serialize());
}
} catch (IOException e) {
logger.warn("Failed to store signed pre key, trying to delete file and retry: {}", e.getMessage());
try {
Files.delete(file.toPath());
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(record.serialize());
}
} catch (IOException e2) {
logger.error("Failed to store signed pre key file {}: {}", file, e2.getMessage());
}
}
}
@Override
public boolean containsSignedPreKey(int signedPreKeyId) {
final var file = getSignedPreKeyFile(signedPreKeyId);
return file.exists();
}
@Override
public void removeSignedPreKey(int signedPreKeyId) {
final var file = getSignedPreKeyFile(signedPreKeyId);
if (!file.exists()) {
return;
}
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete signed pre key file {}: {}", file, e.getMessage());
}
}
private File getSignedPreKeyFile(int signedPreKeyId) {
try {
IOUtils.createPrivateDirectories(signedPreKeysPath);
} catch (IOException e) {
throw new AssertionError("Failed to create signed pre keys path", e);
}
return new File(signedPreKeysPath, String.valueOf(signedPreKeyId));
}
private SignedPreKeyRecord loadSignedPreKeyRecord(final File file) {
try (var inputStream = new FileInputStream(file)) {
return new SignedPreKeyRecord(inputStream.readAllBytes());
} catch (IOException e) {
logger.error("Failed to load signed pre key: {}", e.getMessage());
throw new AssertionError(e);
}
}
}

View file

@ -14,13 +14,14 @@ import org.whispersystems.libsignal.state.PreKeyRecord;
import org.whispersystems.libsignal.state.PreKeyStore; import org.whispersystems.libsignal.state.PreKeyStore;
import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.libsignal.state.SignedPreKeyRecord; import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import org.whispersystems.libsignal.state.SignedPreKeyStore;
import org.whispersystems.signalservice.api.SignalServiceProtocolStore; import org.whispersystems.signalservice.api.SignalServiceProtocolStore;
import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.util.List; import java.util.List;
@JsonIgnoreProperties(value = {"sessionStore", "preKeys"}, allowSetters = true) @JsonIgnoreProperties(value = {"sessionStore", "preKeys", "signedPreKeyStore"}, allowSetters = true)
public class JsonSignalProtocolStore implements SignalServiceProtocolStore { public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
@JsonProperty("preKeys") @JsonProperty("preKeys")
@ -32,9 +33,8 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
private LegacyJsonSessionStore legacySessionStore; private LegacyJsonSessionStore legacySessionStore;
@JsonProperty("signedPreKeyStore") @JsonProperty("signedPreKeyStore")
@JsonDeserialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class) @JsonDeserialize(using = LegacyJsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class)
@JsonSerialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreSerializer.class) private LegacyJsonSignedPreKeyStore legacySignedPreKeyStore;
private JsonSignedPreKeyStore signedPreKeyStore;
@JsonProperty("identityKeyStore") @JsonProperty("identityKeyStore")
@JsonDeserialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreDeserializer.class) @JsonDeserialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreDeserializer.class)
@ -42,6 +42,7 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
private JsonIdentityKeyStore identityKeyStore; private JsonIdentityKeyStore identityKeyStore;
private PreKeyStore preKeyStore; private PreKeyStore preKeyStore;
private SignedPreKeyStore signedPreKeyStore;
private SignalServiceSessionStore sessionStore; private SignalServiceSessionStore sessionStore;
public JsonSignalProtocolStore() { public JsonSignalProtocolStore() {
@ -51,11 +52,12 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
IdentityKeyPair identityKeyPair, IdentityKeyPair identityKeyPair,
int registrationId, int registrationId,
PreKeyStore preKeyStore, PreKeyStore preKeyStore,
SignedPreKeyStore signedPreKeyStore,
SignalServiceSessionStore sessionStore SignalServiceSessionStore sessionStore
) { ) {
this.preKeyStore = preKeyStore; this.preKeyStore = preKeyStore;
this.signedPreKeyStore = signedPreKeyStore;
this.sessionStore = sessionStore; this.sessionStore = sessionStore;
signedPreKeyStore = new JsonSignedPreKeyStore();
this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId); this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId);
} }
@ -67,6 +69,10 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
this.preKeyStore = preKeyStore; this.preKeyStore = preKeyStore;
} }
public void setSignedPreKeyStore(final SignedPreKeyStore signedPreKeyStore) {
this.signedPreKeyStore = signedPreKeyStore;
}
public void setSessionStore(final SignalServiceSessionStore sessionStore) { public void setSessionStore(final SignalServiceSessionStore sessionStore) {
this.sessionStore = sessionStore; this.sessionStore = sessionStore;
} }
@ -75,6 +81,10 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
return legacyPreKeyStore; return legacyPreKeyStore;
} }
public LegacyJsonSignedPreKeyStore getLegacySignedPreKeyStore() {
return legacySignedPreKeyStore;
}
public LegacyJsonSessionStore getLegacySessionStore() { public LegacyJsonSessionStore getLegacySessionStore() {
return legacySessionStore; return legacySessionStore;
} }

View file

@ -1,121 +0,0 @@
package org.asamk.signal.manager.storage.protocol;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.InvalidKeyIdException;
import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import org.whispersystems.libsignal.state.SignedPreKeyStore;
import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
class JsonSignedPreKeyStore implements SignedPreKeyStore {
private final static Logger logger = LoggerFactory.getLogger(JsonSignedPreKeyStore.class);
private final Map<Integer, byte[]> store = new HashMap<>();
public JsonSignedPreKeyStore() {
}
private void addSignedPreKeys(Map<Integer, byte[]> preKeys) {
store.putAll(preKeys);
}
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
try {
if (!store.containsKey(signedPreKeyId)) {
throw new InvalidKeyIdException("No such signedprekeyrecord! " + signedPreKeyId);
}
return new SignedPreKeyRecord(store.get(signedPreKeyId));
} catch (IOException e) {
throw new AssertionError(e);
}
}
@Override
public List<SignedPreKeyRecord> loadSignedPreKeys() {
try {
var results = new LinkedList<SignedPreKeyRecord>();
for (var serialized : store.values()) {
results.add(new SignedPreKeyRecord(serialized));
}
return results;
} catch (IOException e) {
throw new AssertionError(e);
}
}
@Override
public void storeSignedPreKey(int signedPreKeyId, SignedPreKeyRecord record) {
store.put(signedPreKeyId, record.serialize());
}
@Override
public boolean containsSignedPreKey(int signedPreKeyId) {
return store.containsKey(signedPreKeyId);
}
@Override
public void removeSignedPreKey(int signedPreKeyId) {
store.remove(signedPreKeyId);
}
public static class JsonSignedPreKeyStoreDeserializer extends JsonDeserializer<JsonSignedPreKeyStore> {
@Override
public JsonSignedPreKeyStore deserialize(
JsonParser jsonParser, DeserializationContext deserializationContext
) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser);
var preKeyMap = new HashMap<Integer, byte[]>();
if (node.isArray()) {
for (var preKey : node) {
final var preKeyId = preKey.get("id").asInt();
final var preKeyRecord = Base64.getDecoder().decode(preKey.get("record").asText());
preKeyMap.put(preKeyId, preKeyRecord);
}
}
var keyStore = new JsonSignedPreKeyStore();
keyStore.addSignedPreKeys(preKeyMap);
return keyStore;
}
}
public static class JsonSignedPreKeyStoreSerializer extends JsonSerializer<JsonSignedPreKeyStore> {
@Override
public void serialize(
JsonSignedPreKeyStore jsonPreKeyStore, JsonGenerator json, SerializerProvider serializerProvider
) throws IOException {
json.writeStartArray();
for (var signedPreKey : jsonPreKeyStore.store.entrySet()) {
json.writeStartObject();
json.writeNumberField("id", signedPreKey.getKey());
json.writeStringField("record", Base64.getEncoder().encodeToString(signedPreKey.getValue()));
json.writeEndObject();
}
json.writeEndArray();
}
}
}

View file

@ -14,7 +14,7 @@ public class LegacyJsonPreKeyStore {
private final Map<Integer, byte[]> preKeys; private final Map<Integer, byte[]> preKeys;
public LegacyJsonPreKeyStore(final Map<Integer, byte[]> preKeys) { private LegacyJsonPreKeyStore(final Map<Integer, byte[]> preKeys) {
this.preKeys = preKeys; this.preKeys = preKeys;
} }

View file

@ -16,9 +16,10 @@ import java.util.List;
public class LegacyJsonSessionStore { public class LegacyJsonSessionStore {
private final List<SessionInfo> sessions = new ArrayList<>(); private final List<SessionInfo> sessions;
public LegacyJsonSessionStore() { private LegacyJsonSessionStore(final List<SessionInfo> sessions) {
this.sessions = sessions;
} }
public List<SessionInfo> getSessions() { public List<SessionInfo> getSessions() {
@ -33,7 +34,7 @@ public class LegacyJsonSessionStore {
) throws IOException { ) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser); JsonNode node = jsonParser.getCodec().readTree(jsonParser);
var sessionStore = new LegacyJsonSessionStore(); var sessions = new ArrayList<SessionInfo>();
if (node.isArray()) { if (node.isArray()) {
for (var session : node) { for (var session : node) {
@ -50,11 +51,11 @@ public class LegacyJsonSessionStore {
final var deviceId = session.get("deviceId").asInt(); final var deviceId = session.get("deviceId").asInt();
final var record = Base64.getDecoder().decode(session.get("record").asText()); final var record = Base64.getDecoder().decode(session.get("record").asText());
var sessionInfo = new SessionInfo(serviceAddress, deviceId, record); var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
sessionStore.sessions.add(sessionInfo); sessions.add(sessionInfo);
} }
} }
return sessionStore; return new LegacyJsonSessionStore(sessions);
} }
} }
} }

View file

@ -0,0 +1,45 @@
package org.asamk.signal.manager.storage.protocol;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
public class LegacyJsonSignedPreKeyStore {
private final Map<Integer, byte[]> signedPreKeys;
private LegacyJsonSignedPreKeyStore(final Map<Integer, byte[]> signedPreKeys) {
this.signedPreKeys = signedPreKeys;
}
public Map<Integer, byte[]> getSignedPreKeys() {
return signedPreKeys;
}
public static class JsonSignedPreKeyStoreDeserializer extends JsonDeserializer<LegacyJsonSignedPreKeyStore> {
@Override
public LegacyJsonSignedPreKeyStore deserialize(
JsonParser jsonParser, DeserializationContext deserializationContext
) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser);
var preKeyMap = new HashMap<Integer, byte[]>();
if (node.isArray()) {
for (var preKey : node) {
final var preKeyId = preKey.get("id").asInt();
final var preKeyRecord = Base64.getDecoder().decode(preKey.get("record").asText());
preKeyMap.put(preKeyId, preKeyRecord);
}
}
return new LegacyJsonSignedPreKeyStore(preKeyMap);
}
}
}