Refactor sessions store

This commit is contained in:
AsamK 2021-04-15 22:33:35 +02:00
parent 9f5347964b
commit f77519445c
9 changed files with 449 additions and 258 deletions

View file

@ -263,7 +263,7 @@ public class Manager implements Closeable {
}
private IdentityKeyPair getIdentityKeyPair() {
return account.getSignalProtocolStore().getIdentityKeyPair();
return account.getIdentityKeyPair();
}
public int getDeviceId() {
@ -336,7 +336,7 @@ public class Manager implements Closeable {
public void updateAccountAttributes() throws IOException {
accountManager.setAccountAttributes(null,
account.getSignalProtocolStore().getLocalRegistrationId(),
account.getLocalRegistrationId(),
true,
// set legacy pin only if no KBS master key is set
account.getPinMasterKey() == null ? account.getRegistrationLockPin() : null,
@ -1441,7 +1441,7 @@ public class Manager implements Closeable {
}
private void handleEndSession(SignalServiceAddress source) {
account.getSignalProtocolStore().deleteAllSessions(source);
account.getSessionStore().deleteAllSessions(source.getIdentifier());
}
private List<HandleAction> handleSignalServiceDataMessage(

View file

@ -163,10 +163,10 @@ public class RegistrationManager implements Closeable {
account.setRegistered(true);
account.setUuid(UuidUtil.parseOrNull(response.getUuid()));
account.setRegistrationLockPin(pin);
account.getSignalProtocolStore().archiveAllSessions();
account.getSessionStore().archiveAllSessions();
account.getSignalProtocolStore()
.saveIdentity(account.getSelfAddress(),
account.getSignalProtocolStore().getIdentityKeyPair().getPublicKey(),
account.getIdentityKeyPair().getPublicKey(),
TrustLevel.TRUSTED_VERIFIED);
Manager m = null;
@ -194,7 +194,7 @@ public class RegistrationManager implements Closeable {
) throws IOException {
return accountManager.verifyAccountWithCode(verificationCode,
null,
account.getSignalProtocolStore().getLocalRegistrationId(),
account.getLocalRegistrationId(),
true,
legacyPin,
registrationLock,

View file

@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.protocol.SignalServiceAddressResolver;
import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientStore;
import org.asamk.signal.manager.storage.sessions.SessionStore;
import org.asamk.signal.manager.storage.stickers.StickerStore;
import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore;
import org.asamk.signal.manager.util.IOUtils;
@ -31,7 +32,9 @@ import org.signal.zkgroup.profiles.ProfileKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.PreKeyRecord;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import org.whispersystems.libsignal.util.Medium;
import org.whispersystems.libsignal.util.Pair;
@ -77,6 +80,7 @@ public class SignalAccount implements Closeable {
private boolean registered = false;
private JsonSignalProtocolStore signalProtocolStore;
private SessionStore sessionStore;
private JsonGroupStore groupStore;
private JsonContactsStore contactStore;
private RecipientStore recipientStore;
@ -125,11 +129,13 @@ public class SignalAccount implements Closeable {
account.username = username;
account.profileKey = profileKey;
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
account.contactStore = new JsonContactsStore();
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients);
account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
account.recipientStore::resolveRecipient);
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore();
@ -166,11 +172,13 @@ public class SignalAccount implements Closeable {
account.password = password;
account.profileKey = profileKey;
account.deviceId = deviceId;
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
account.contactStore = new JsonContactsStore();
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients);
account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
account.recipientStore::resolveRecipient);
account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore();
@ -210,7 +218,7 @@ public class SignalAccount implements Closeable {
}
private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
// TODO
sessionStore.mergeRecipients(recipientId, toBeMergedRecipientId);
}
public static File getFileName(File dataPath, String username) {
@ -229,6 +237,10 @@ public class SignalAccount implements Closeable {
return new File(getUserPath(dataPath, username), "group-cache");
}
private static File getSessionsPath(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "sessions");
}
private static File getRecipientsStoreFile(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "recipients-store");
}
@ -304,6 +316,19 @@ public class SignalAccount implements Closeable {
signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
JsonSignalProtocolStore.class);
sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient);
if (signalProtocolStore.getLegacySessionStore() != null) {
logger.debug("Migrating legacy session store.");
for (var session : signalProtocolStore.getLegacySessionStore().getSessions()) {
try {
sessionStore.storeSession(new SignalProtocolAddress(session.address.getIdentifier(),
session.deviceId), new SessionRecord(session.sessionRecord));
} catch (IOException e) {
logger.warn("Failed to migrate session, ignoring", e);
}
}
}
signalProtocolStore.setSessionStore(sessionStore);
registered = Utils.getNotNullNode(rootNode, "registered").asBoolean();
var groupStoreNode = rootNode.get("groupStore");
if (groupStoreNode != null) {
@ -355,10 +380,6 @@ public class SignalAccount implements Closeable {
}
}
for (var session : signalProtocolStore.getSessions()) {
session.address = recipientStore.resolveServiceAddress(session.address);
}
for (var identity : signalProtocolStore.getIdentities()) {
identity.setAddress(recipientStore.resolveServiceAddress(identity.getAddress()));
}
@ -464,6 +485,10 @@ public class SignalAccount implements Closeable {
return signalProtocolStore;
}
public SessionStore getSessionStore() {
return sessionStore;
}
public JsonGroupStore getGroupStore() {
return groupStore;
}
@ -516,6 +541,14 @@ public class SignalAccount implements Closeable {
return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
}
public IdentityKeyPair getIdentityKeyPair() {
return signalProtocolStore.getIdentityKeyPair();
}
public int getLocalRegistrationId() {
return signalProtocolStore.getLocalRegistrationId();
}
public String getPassword() {
return password;
}

View file

@ -1,214 +0,0 @@
package org.asamk.signal.manager.storage.protocol;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.SerializerProvider;
import org.asamk.signal.manager.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.util.UuidUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
class JsonSessionStore implements SignalServiceSessionStore {
private final static Logger logger = LoggerFactory.getLogger(JsonSessionStore.class);
private final List<SessionInfo> sessions = new ArrayList<>();
private SignalServiceAddressResolver resolver;
public JsonSessionStore() {
}
public void setResolver(final SignalServiceAddressResolver resolver) {
this.resolver = resolver;
}
private SignalServiceAddress resolveSignalServiceAddress(String identifier) {
if (resolver != null) {
return resolver.resolveSignalServiceAddress(identifier);
} else {
return Utils.getSignalServiceAddressFromIdentifier(identifier);
}
}
@Override
public synchronized SessionRecord loadSession(SignalProtocolAddress address) {
var serviceAddress = resolveSignalServiceAddress(address.getName());
for (var info : sessions) {
if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
try {
return new SessionRecord(info.sessionRecord);
} catch (IOException e) {
logger.warn("Failed to load session, resetting session: {}", e.getMessage());
return new SessionRecord();
}
}
}
return new SessionRecord();
}
public synchronized List<SessionInfo> getSessions() {
return sessions;
}
@Override
public synchronized List<Integer> getSubDeviceSessions(String name) {
var serviceAddress = resolveSignalServiceAddress(name);
var deviceIds = new LinkedList<Integer>();
for (var info : sessions) {
if (info.address.matches(serviceAddress) && info.deviceId != 1) {
deviceIds.add(info.deviceId);
}
}
return deviceIds;
}
@Override
public synchronized void storeSession(SignalProtocolAddress address, SessionRecord record) {
var serviceAddress = resolveSignalServiceAddress(address.getName());
for (var info : sessions) {
if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
if (!info.address.getUuid().isPresent() || !info.address.getNumber().isPresent()) {
info.address = serviceAddress;
}
info.sessionRecord = record.serialize();
return;
}
}
sessions.add(new SessionInfo(serviceAddress, address.getDeviceId(), record.serialize()));
}
@Override
public synchronized boolean containsSession(SignalProtocolAddress address) {
var serviceAddress = resolveSignalServiceAddress(address.getName());
for (var info : sessions) {
if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
final SessionRecord sessionRecord;
try {
sessionRecord = new SessionRecord(info.sessionRecord);
} catch (IOException e) {
logger.warn("Failed to check session: {}", e.getMessage());
return false;
}
return sessionRecord.hasSenderChain()
&& sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
}
}
return false;
}
@Override
public synchronized void deleteSession(SignalProtocolAddress address) {
var serviceAddress = resolveSignalServiceAddress(address.getName());
sessions.removeIf(info -> info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId());
}
@Override
public synchronized void deleteAllSessions(String name) {
var serviceAddress = resolveSignalServiceAddress(name);
deleteAllSessions(serviceAddress);
}
public synchronized void deleteAllSessions(SignalServiceAddress serviceAddress) {
sessions.removeIf(info -> info.address.matches(serviceAddress));
}
@Override
public void archiveSession(final SignalProtocolAddress address) {
final var sessionRecord = loadSession(address);
if (sessionRecord == null) {
return;
}
sessionRecord.archiveCurrentState();
storeSession(address, sessionRecord);
}
public void archiveAllSessions() {
for (var info : sessions) {
try {
final var sessionRecord = new SessionRecord(info.sessionRecord);
sessionRecord.archiveCurrentState();
info.sessionRecord = sessionRecord.serialize();
} catch (IOException ignored) {
}
}
}
public static class JsonSessionStoreDeserializer extends JsonDeserializer<JsonSessionStore> {
@Override
public JsonSessionStore deserialize(
JsonParser jsonParser, DeserializationContext deserializationContext
) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser);
var sessionStore = new JsonSessionStore();
if (node.isArray()) {
for (var session : node) {
var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
if (UuidUtil.isUuid(sessionName)) {
// Ignore sessions that were incorrectly created with UUIDs as name
continue;
}
var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
final var serviceAddress = uuid == null
? Utils.getSignalServiceAddressFromIdentifier(sessionName)
: new SignalServiceAddress(uuid, sessionName);
final var deviceId = session.get("deviceId").asInt();
final var record = Base64.getDecoder().decode(session.get("record").asText());
var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
sessionStore.sessions.add(sessionInfo);
}
}
return sessionStore;
}
}
public static class JsonSessionStoreSerializer extends JsonSerializer<JsonSessionStore> {
@Override
public void serialize(
JsonSessionStore jsonSessionStore, JsonGenerator json, SerializerProvider serializerProvider
) throws IOException {
json.writeStartArray();
for (var sessionInfo : jsonSessionStore.sessions) {
json.writeStartObject();
if (sessionInfo.address.getNumber().isPresent()) {
json.writeStringField("name", sessionInfo.address.getNumber().get());
}
if (sessionInfo.address.getUuid().isPresent()) {
json.writeStringField("uuid", sessionInfo.address.getUuid().get().toString());
}
json.writeNumberField("deviceId", sessionInfo.deviceId);
json.writeStringField("record", Base64.getEncoder().encodeToString(sessionInfo.sessionRecord));
json.writeEndObject();
}
json.writeEndArray();
}
}
}

View file

@ -1,10 +1,12 @@
package org.asamk.signal.manager.storage.protocol;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.asamk.signal.manager.TrustLevel;
import org.asamk.signal.manager.storage.sessions.SessionStore;
import org.whispersystems.libsignal.IdentityKey;
import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.InvalidKeyIdException;
@ -17,6 +19,7 @@ import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import java.util.List;
@JsonIgnoreProperties(value = "sessionStore", allowSetters = true)
public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
@JsonProperty("preKeys")
@ -25,9 +28,8 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
private JsonPreKeyStore preKeyStore;
@JsonProperty("sessionStore")
@JsonDeserialize(using = JsonSessionStore.JsonSessionStoreDeserializer.class)
@JsonSerialize(using = JsonSessionStore.JsonSessionStoreSerializer.class)
private JsonSessionStore sessionStore;
@JsonDeserialize(using = LegacyJsonSessionStore.JsonSessionStoreDeserializer.class)
private LegacyJsonSessionStore legacySessionStore;
@JsonProperty("signedPreKeyStore")
@JsonDeserialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class)
@ -39,33 +41,30 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
@JsonSerialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreSerializer.class)
private JsonIdentityKeyStore identityKeyStore;
private SessionStore sessionStore;
public JsonSignalProtocolStore() {
}
public JsonSignalProtocolStore(
JsonPreKeyStore preKeyStore,
JsonSessionStore sessionStore,
JsonSignedPreKeyStore signedPreKeyStore,
JsonIdentityKeyStore identityKeyStore
) {
this.preKeyStore = preKeyStore;
this.sessionStore = sessionStore;
this.signedPreKeyStore = signedPreKeyStore;
this.identityKeyStore = identityKeyStore;
}
public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId) {
public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId, SessionStore sessionStore) {
preKeyStore = new JsonPreKeyStore();
sessionStore = new JsonSessionStore();
this.sessionStore = sessionStore;
signedPreKeyStore = new JsonSignedPreKeyStore();
this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId);
}
public void setResolver(final SignalServiceAddressResolver resolver) {
sessionStore.setResolver(resolver);
identityKeyStore.setResolver(resolver);
}
public void setSessionStore(final SessionStore sessionStore) {
this.sessionStore = sessionStore;
}
public LegacyJsonSessionStore getLegacySessionStore() {
return legacySessionStore;
}
@Override
public IdentityKeyPair getIdentityKeyPair() {
return identityKeyStore.getIdentityKeyPair();
@ -142,10 +141,6 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
return sessionStore.loadSession(address);
}
public List<SessionInfo> getSessions() {
return sessionStore.getSessions();
}
@Override
public List<Integer> getSubDeviceSessions(String name) {
return sessionStore.getSubDeviceSessions(name);
@ -171,19 +166,11 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
sessionStore.deleteAllSessions(name);
}
public void deleteAllSessions(SignalServiceAddress serviceAddress) {
sessionStore.deleteAllSessions(serviceAddress);
}
@Override
public void archiveSession(final SignalProtocolAddress address) {
sessionStore.archiveSession(address);
}
public void archiveAllSessions() {
sessionStore.archiveAllSessions();
}
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
return signedPreKeyStore.loadSignedPreKey(signedPreKeyId);

View file

@ -0,0 +1,60 @@
package org.asamk.signal.manager.storage.protocol;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import org.asamk.signal.manager.util.Utils;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.util.UuidUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
public class LegacyJsonSessionStore {
private final List<SessionInfo> sessions = new ArrayList<>();
public LegacyJsonSessionStore() {
}
public List<SessionInfo> getSessions() {
return sessions;
}
public static class JsonSessionStoreDeserializer extends JsonDeserializer<LegacyJsonSessionStore> {
@Override
public LegacyJsonSessionStore deserialize(
JsonParser jsonParser, DeserializationContext deserializationContext
) throws IOException {
JsonNode node = jsonParser.getCodec().readTree(jsonParser);
var sessionStore = new LegacyJsonSessionStore();
if (node.isArray()) {
for (var session : node) {
var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
if (UuidUtil.isUuid(sessionName)) {
// Ignore sessions that were incorrectly created with UUIDs as name
continue;
}
var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
final var serviceAddress = uuid == null
? Utils.getSignalServiceAddressFromIdentifier(sessionName)
: new SignalServiceAddress(uuid, sessionName);
final var deviceId = session.get("deviceId").asInt();
final var record = Base64.getDecoder().decode(session.get("record").asText());
var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
sessionStore.sessions.add(sessionInfo);
}
}
return sessionStore;
}
}
}

View file

@ -8,7 +8,11 @@ public class RecipientId {
this.id = id;
}
long getId() {
public static RecipientId of(long id) {
return new RecipientId(id);
}
public long getId() {
return id;
}

View file

@ -0,0 +1,8 @@
package org.asamk.signal.manager.storage.recipients;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
public interface RecipientResolver {
RecipientId resolveRecipient(SignalServiceAddress address);
}

View file

@ -0,0 +1,313 @@
package org.asamk.signal.manager.storage.sessions;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.util.IOUtils;
import org.asamk.signal.manager.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
public class SessionStore implements SignalServiceSessionStore {
private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
private final File sessionsPath;
private final RecipientResolver resolver;
public SessionStore(
final File sessionsPath, final RecipientResolver resolver
) {
this.sessionsPath = sessionsPath;
this.resolver = resolver;
}
@Override
public SessionRecord loadSession(SignalProtocolAddress address) {
final var key = getKey(address);
synchronized (cachedSessions) {
final var session = loadSessionLocked(key);
if (session == null) {
return new SessionRecord();
}
return session;
}
}
@Override
public List<Integer> getSubDeviceSessions(String name) {
final var recipientId = resolveRecipient(name);
synchronized (cachedSessions) {
return getKeysLocked(recipientId).stream()
// get all sessions for recipient except main device session
.filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
.map(Key::getDeviceId)
.collect(Collectors.toList());
}
}
@Override
public void storeSession(SignalProtocolAddress address, SessionRecord session) {
final var key = getKey(address);
synchronized (cachedSessions) {
storeSessionLocked(key, session);
}
}
@Override
public boolean containsSession(SignalProtocolAddress address) {
final var key = getKey(address);
synchronized (cachedSessions) {
final var session = loadSessionLocked(key);
if (session == null) {
return false;
}
return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
}
}
@Override
public void deleteSession(SignalProtocolAddress address) {
final var key = getKey(address);
synchronized (cachedSessions) {
deleteSessionLocked(key);
}
}
@Override
public void deleteAllSessions(String name) {
final var recipientId = resolveRecipient(name);
deleteAllSessions(recipientId);
}
public void deleteAllSessions(RecipientId recipientId) {
synchronized (cachedSessions) {
final var keys = getKeysLocked(recipientId);
for (var key : keys) {
deleteSessionLocked(key);
}
}
}
@Override
public void archiveSession(final SignalProtocolAddress address) {
final var key = getKey(address);
synchronized (cachedSessions) {
archiveSessionLocked(key);
}
}
public void archiveAllSessions() {
synchronized (cachedSessions) {
final var keys = getKeysLocked();
for (var key : keys) {
archiveSessionLocked(key);
}
}
}
public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
synchronized (cachedSessions) {
final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
if (!otherHasSession) {
return;
}
final var hasSession = getKeysLocked(recipientId).size() > 0;
if (hasSession) {
logger.debug("To be merged recipient had sessions, deleting.");
deleteAllSessions(toBeMergedRecipientId);
} else {
logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
final var keys = getKeysLocked(toBeMergedRecipientId);
for (var key : keys) {
final var session = loadSessionLocked(key);
deleteSessionLocked(key);
if (session == null) {
continue;
}
final var newKey = new Key(recipientId, key.getDeviceId());
storeSessionLocked(newKey, session);
}
}
}
}
/**
* @param identifier can be either a serialized uuid or a e164 phone number
*/
private RecipientId resolveRecipient(String identifier) {
return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
}
private Key getKey(final SignalProtocolAddress address) {
final var recipientId = resolveRecipient(address.getName());
return new Key(recipientId, address.getDeviceId());
}
private List<Key> getKeysLocked(RecipientId recipientId) {
final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
if (files == null) {
return List.of();
}
return parseFileNames(files);
}
private Collection<Key> getKeysLocked() {
final var files = sessionsPath.listFiles();
if (files == null) {
return List.of();
}
return parseFileNames(files);
}
final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
private List<Key> parseFileNames(final File[] files) {
return Arrays.stream(files)
.map(f -> sessionFileNamePattern.matcher(f.getName()))
.filter(Matcher::matches)
.map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
Integer.parseInt(matcher.group(2))))
.collect(Collectors.toList());
}
private File getSessionPath(Key key) {
try {
IOUtils.createPrivateDirectories(sessionsPath);
} catch (IOException e) {
throw new AssertionError("Failed to create sessions path", e);
}
return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
}
private SessionRecord loadSessionLocked(final Key key) {
{
final var session = cachedSessions.get(key);
if (session != null) {
return session;
}
}
final var file = getSessionPath(key);
if (!file.exists()) {
return null;
}
try (var inputStream = new FileInputStream(file)) {
final var session = new SessionRecord(inputStream.readAllBytes());
cachedSessions.put(key, session);
return session;
} catch (IOException e) {
logger.warn("Failed to load session, resetting session: {}", e.getMessage());
return null;
}
}
private void storeSessionLocked(final Key key, final SessionRecord session) {
cachedSessions.put(key, session);
final var file = getSessionPath(key);
try {
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(session.serialize());
}
} catch (IOException e) {
logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
try {
Files.delete(file.toPath());
try (var outputStream = new FileOutputStream(file)) {
outputStream.write(session.serialize());
}
} catch (IOException e2) {
logger.error("Failed to store session file {}: {}", file, e2.getMessage());
}
}
}
private void archiveSessionLocked(final Key key) {
final var session = loadSessionLocked(key);
if (session == null) {
return;
}
session.archiveCurrentState();
storeSessionLocked(key, session);
}
private void deleteSessionLocked(final Key key) {
cachedSessions.remove(key);
final var file = getSessionPath(key);
if (!file.exists()) {
return;
}
try {
Files.delete(file.toPath());
} catch (IOException e) {
logger.error("Failed to delete session file {}: {}", file, e.getMessage());
}
}
private static final class Key {
private final RecipientId recipientId;
private final int deviceId;
public Key(final RecipientId recipientId, final int deviceId) {
this.recipientId = recipientId;
this.deviceId = deviceId;
}
public RecipientId getRecipientId() {
return recipientId;
}
public int getDeviceId() {
return deviceId;
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final var key = (Key) o;
if (deviceId != key.deviceId) return false;
return recipientId.equals(key.recipientId);
}
@Override
public int hashCode() {
int result = recipientId.hashCode();
result = 31 * result + deviceId;
return result;
}
}
}