Refresh pre keys for PNI identity

Fixes #930
This commit is contained in:
AsamK 2022-04-11 20:05:02 +02:00
parent 2a20e70aab
commit 945ff44de3
2 changed files with 132 additions and 51 deletions

View file

@ -45,32 +45,31 @@ public class PreKeyHelper {
} }
public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
if (serviceIdType != ServiceIdType.ACI) { final var oneTimePreKeys = generatePreKeys(serviceIdType);
// TODO implement final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return; return;
} }
var oneTimePreKeys = generatePreKeys(); final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
final var identityKeyPair = account.getAciIdentityKeyPair();
var signedPreKeyRecord = generateSignedPreKey(identityKeyPair);
dependencies.getAccountManager() dependencies.getAccountManager()
.setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys); .setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys);
} }
private List<PreKeyRecord> generatePreKeys() { private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
final var offset = account.getPreKeyIdOffset(); final var offset = account.getPreKeyIdOffset(serviceIdType);
var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE); var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE);
account.addPreKeys(records); account.addPreKeys(serviceIdType, records);
return records; return records;
} }
private SignedPreKeyRecord generateSignedPreKey(IdentityKeyPair identityKeyPair) { private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
final var signedPreKeyId = account.getNextSignedPreKeyId(); final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId); var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId);
account.addSignedPreKey(record); account.addSignedPreKey(serviceIdType, record);
return record; return record;
} }

View file

@ -53,6 +53,7 @@ import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.push.PNI; import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId; import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.storage.StorageKey; import org.whispersystems.signalservice.api.storage.StorageKey;
import org.whispersystems.signalservice.api.util.CredentialsProvider; import org.whispersystems.signalservice.api.util.CredentialsProvider;
@ -106,8 +107,10 @@ public class SignalAccount implements Closeable {
private StorageKey storageKey; private StorageKey storageKey;
private long storageManifestVersion = -1; private long storageManifestVersion = -1;
private ProfileKey profileKey; private ProfileKey profileKey;
private int preKeyIdOffset = 1; private int aciPreKeyIdOffset = 1;
private int nextSignedPreKeyId = 1; private int aciNextSignedPreKeyId = 1;
private int pniPreKeyIdOffset = 1;
private int pniNextSignedPreKeyId = 1;
private IdentityKeyPair aciIdentityKeyPair; private IdentityKeyPair aciIdentityKeyPair;
private IdentityKeyPair pniIdentityKeyPair; private IdentityKeyPair pniIdentityKeyPair;
private int localRegistrationId; private int localRegistrationId;
@ -117,8 +120,10 @@ public class SignalAccount implements Closeable {
private boolean registered = false; private boolean registered = false;
private SignalProtocolStore signalProtocolStore; private SignalProtocolStore signalProtocolStore;
private PreKeyStore preKeyStore; private PreKeyStore aciPreKeyStore;
private SignedPreKeyStore signedPreKeyStore; private SignedPreKeyStore aciSignedPreKeyStore;
private PreKeyStore pniPreKeyStore;
private SignedPreKeyStore pniSignedPreKeyStore;
private SessionStore sessionStore; private SessionStore sessionStore;
private IdentityKeyStore identityKeyStore; private IdentityKeyStore identityKeyStore;
private SenderKeyStore senderKeyStore; private SenderKeyStore senderKeyStore;
@ -259,10 +264,14 @@ public class SignalAccount implements Closeable {
} }
private void clearAllPreKeys() { private void clearAllPreKeys() {
this.preKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.nextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.getPreKeyStore().removeAllPreKeys(); this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.getSignedPreKeyStore().removeAllSignedPreKeys(); this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.getAciPreKeyStore().removeAllPreKeys();
this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
this.getPniPreKeyStore().removeAllPreKeys();
this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
save(); save();
} }
@ -407,14 +416,22 @@ public class SignalAccount implements Closeable {
return new File(getUserPath(dataPath, account), "group-cache"); return new File(getUserPath(dataPath, account), "group-cache");
} }
private static File getPreKeysPath(File dataPath, String account) { private static File getAciPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "pre-keys"); return new File(getUserPath(dataPath, account), "pre-keys");
} }
private static File getSignedPreKeysPath(File dataPath, String account) { private static File getAciSignedPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "signed-pre-keys"); return new File(getUserPath(dataPath, account), "signed-pre-keys");
} }
private static File getPniPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "pre-keys-pni");
}
private static File getPniSignedPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "signed-pre-keys-pni");
}
private static File getIdentitiesPath(File dataPath, String account) { private static File getIdentitiesPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "identities"); return new File(getUserPath(dataPath, account), "identities");
} }
@ -528,14 +545,24 @@ public class SignalAccount implements Closeable {
storageManifestVersion = rootNode.get("storageManifestVersion").asLong(); storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
} }
if (rootNode.hasNonNull("preKeyIdOffset")) { if (rootNode.hasNonNull("preKeyIdOffset")) {
preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
} else { } else {
preKeyIdOffset = 1; aciPreKeyIdOffset = 1;
} }
if (rootNode.hasNonNull("nextSignedPreKeyId")) { if (rootNode.hasNonNull("nextSignedPreKeyId")) {
nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
} else { } else {
nextSignedPreKeyId = 1; aciNextSignedPreKeyId = 1;
}
if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
} else {
pniPreKeyIdOffset = 1;
}
if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
} else {
pniNextSignedPreKeyId = 1;
} }
if (rootNode.hasNonNull("profileKey")) { if (rootNode.hasNonNull("profileKey")) {
try { try {
@ -618,7 +645,7 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy pre key store."); logger.debug("Migrating legacy pre key store.");
for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) { for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
try { try {
getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
logger.warn("Failed to migrate pre key, ignoring", e); logger.warn("Failed to migrate pre key, ignoring", e);
} }
@ -630,7 +657,8 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy signed pre key store."); logger.debug("Migrating legacy signed pre key store.");
for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) { for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
try { try {
getSignedPreKeyStore().storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue())); getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(),
new SignedPreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) { } catch (InvalidMessageException e) {
logger.warn("Failed to migrate signed pre key, ignoring", e); logger.warn("Failed to migrate signed pre key, ignoring", e);
} }
@ -813,8 +841,10 @@ public class SignalAccount implements Closeable {
.put("storageKey", .put("storageKey",
storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize())) storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
.put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion) .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
.put("preKeyIdOffset", preKeyIdOffset) .put("preKeyIdOffset", aciPreKeyIdOffset)
.put("nextSignedPreKeyId", nextSignedPreKeyId) .put("nextSignedPreKeyId", aciNextSignedPreKeyId)
.put("pniPreKeyIdOffset", pniPreKeyIdOffset)
.put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
.put("profileKey", .put("profileKey",
profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
.put("registered", registered) .put("registered", registered)
@ -852,25 +882,63 @@ public class SignalAccount implements Closeable {
return new Pair<>(fileChannel, lock); return new Pair<>(fileChannel, lock);
} }
public void addPreKeys(List<PreKeyRecord> records) { public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciPreKeys(records);
} else {
addPniPreKeys(records);
}
}
public void addAciPreKeys(List<PreKeyRecord> records) {
for (var record : records) { for (var record : records) {
if (preKeyIdOffset != record.getId()) { if (aciPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyIdOffset); logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
throw new AssertionError("Invalid pre key id"); throw new AssertionError("Invalid pre key id");
} }
getPreKeyStore().storePreKey(record.getId(), record); getAciPreKeyStore().storePreKey(record.getId(), record);
preKeyIdOffset = (preKeyIdOffset + 1) % Medium.MAX_VALUE; aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE;
} }
save(); save();
} }
public void addSignedPreKey(SignedPreKeyRecord record) { public void addPniPreKeys(List<PreKeyRecord> records) {
if (nextSignedPreKeyId != record.getId()) { for (var record : records) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), nextSignedPreKeyId); if (pniPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getPniPreKeyStore().storePreKey(record.getId(), record);
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE;
}
save();
}
public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciSignedPreKey(record);
} else {
addPniSignedPreKey(record);
}
}
public void addAciSignedPreKey(SignedPreKeyRecord record) {
if (aciNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id"); throw new AssertionError("Invalid signed pre key id");
} }
getSignedPreKeyStore().storeSignedPreKey(record.getId(), record); getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
nextSignedPreKeyId = (nextSignedPreKeyId + 1) % Medium.MAX_VALUE; aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
save();
}
public void addPniSignedPreKey(SignedPreKeyRecord record) {
if (pniNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id");
}
getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
save(); save();
} }
@ -906,22 +974,32 @@ public class SignalAccount implements Closeable {
public SignalServiceAccountDataStore getSignalServiceAccountDataStore() { public SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
return getOrCreate(() -> signalProtocolStore, return getOrCreate(() -> signalProtocolStore,
() -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(), () -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getSignedPreKeyStore(), getAciSignedPreKeyStore(),
getSessionStore(), getSessionStore(),
getIdentityKeyStore(), getIdentityKeyStore(),
getSenderKeyStore(), getSenderKeyStore(),
this::isMultiDevice)); this::isMultiDevice));
} }
private PreKeyStore getPreKeyStore() { private PreKeyStore getAciPreKeyStore() {
return getOrCreate(() -> preKeyStore, return getOrCreate(() -> aciPreKeyStore,
() -> preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, accountPath))); () -> aciPreKeyStore = new PreKeyStore(getAciPreKeysPath(dataPath, accountPath)));
} }
private SignedPreKeyStore getSignedPreKeyStore() { private SignedPreKeyStore getAciSignedPreKeyStore() {
return getOrCreate(() -> signedPreKeyStore, return getOrCreate(() -> aciSignedPreKeyStore,
() -> signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, accountPath))); () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAciSignedPreKeysPath(dataPath, accountPath)));
}
private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getPniPreKeysPath(dataPath, accountPath)));
}
private SignedPreKeyStore getPniSignedPreKeyStore() {
return getOrCreate(() -> pniSignedPreKeyStore,
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getPniSignedPreKeysPath(dataPath, accountPath)));
} }
public SessionStore getSessionStore() { public SessionStore getSessionStore() {
@ -1078,6 +1156,10 @@ public class SignalAccount implements Closeable {
return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID; return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
} }
public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair;
}
public IdentityKeyPair getAciIdentityKeyPair() { public IdentityKeyPair getAciIdentityKeyPair() {
return aciIdentityKeyPair; return aciIdentityKeyPair;
} }
@ -1157,12 +1239,12 @@ public class SignalAccount implements Closeable {
return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey()); return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey());
} }
public int getPreKeyIdOffset() { public int getPreKeyIdOffset(ServiceIdType serviceIdType) {
return preKeyIdOffset; return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset;
} }
public int getNextSignedPreKeyId() { public int getNextSignedPreKeyId(ServiceIdType serviceIdType) {
return nextSignedPreKeyId; return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
} }
public boolean isRegistered() { public boolean isRegistered() {