Refresh pre keys for PNI identity

Fixes #930
This commit is contained in:
AsamK 2022-04-11 20:05:02 +02:00
parent 2a20e70aab
commit 945ff44de3
2 changed files with 132 additions and 51 deletions

View file

@ -45,32 +45,31 @@ public class PreKeyHelper {
}
public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
if (serviceIdType != ServiceIdType.ACI) {
// TODO implement
final var oneTimePreKeys = generatePreKeys(serviceIdType);
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return;
}
var oneTimePreKeys = generatePreKeys();
final var identityKeyPair = account.getAciIdentityKeyPair();
var signedPreKeyRecord = generateSignedPreKey(identityKeyPair);
final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
dependencies.getAccountManager()
.setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys);
}
private List<PreKeyRecord> generatePreKeys() {
final var offset = account.getPreKeyIdOffset();
private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
final var offset = account.getPreKeyIdOffset(serviceIdType);
var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE);
account.addPreKeys(records);
account.addPreKeys(serviceIdType, records);
return records;
}
private SignedPreKeyRecord generateSignedPreKey(IdentityKeyPair identityKeyPair) {
final var signedPreKeyId = account.getNextSignedPreKeyId();
private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId);
account.addSignedPreKey(record);
account.addSignedPreKey(serviceIdType, record);
return record;
}

View file

@ -53,6 +53,7 @@ import org.whispersystems.signalservice.api.push.ACI;
import org.whispersystems.signalservice.api.push.DistributionId;
import org.whispersystems.signalservice.api.push.PNI;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceIdType;
import org.whispersystems.signalservice.api.push.SignalServiceAddress;
import org.whispersystems.signalservice.api.storage.StorageKey;
import org.whispersystems.signalservice.api.util.CredentialsProvider;
@ -106,8 +107,10 @@ public class SignalAccount implements Closeable {
private StorageKey storageKey;
private long storageManifestVersion = -1;
private ProfileKey profileKey;
private int preKeyIdOffset = 1;
private int nextSignedPreKeyId = 1;
private int aciPreKeyIdOffset = 1;
private int aciNextSignedPreKeyId = 1;
private int pniPreKeyIdOffset = 1;
private int pniNextSignedPreKeyId = 1;
private IdentityKeyPair aciIdentityKeyPair;
private IdentityKeyPair pniIdentityKeyPair;
private int localRegistrationId;
@ -117,8 +120,10 @@ public class SignalAccount implements Closeable {
private boolean registered = false;
private SignalProtocolStore signalProtocolStore;
private PreKeyStore preKeyStore;
private SignedPreKeyStore signedPreKeyStore;
private PreKeyStore aciPreKeyStore;
private SignedPreKeyStore aciSignedPreKeyStore;
private PreKeyStore pniPreKeyStore;
private SignedPreKeyStore pniSignedPreKeyStore;
private SessionStore sessionStore;
private IdentityKeyStore identityKeyStore;
private SenderKeyStore senderKeyStore;
@ -259,10 +264,14 @@ public class SignalAccount implements Closeable {
}
private void clearAllPreKeys() {
this.preKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.nextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.getPreKeyStore().removeAllPreKeys();
this.getSignedPreKeyStore().removeAllSignedPreKeys();
this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
this.getAciPreKeyStore().removeAllPreKeys();
this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
this.getPniPreKeyStore().removeAllPreKeys();
this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
save();
}
@ -407,14 +416,22 @@ public class SignalAccount implements Closeable {
return new File(getUserPath(dataPath, account), "group-cache");
}
private static File getPreKeysPath(File dataPath, String account) {
private static File getAciPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "pre-keys");
}
private static File getSignedPreKeysPath(File dataPath, String account) {
private static File getAciSignedPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "signed-pre-keys");
}
private static File getPniPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "pre-keys-pni");
}
private static File getPniSignedPreKeysPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "signed-pre-keys-pni");
}
private static File getIdentitiesPath(File dataPath, String account) {
return new File(getUserPath(dataPath, account), "identities");
}
@ -528,14 +545,24 @@ public class SignalAccount implements Closeable {
storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
}
if (rootNode.hasNonNull("preKeyIdOffset")) {
preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
} else {
preKeyIdOffset = 1;
aciPreKeyIdOffset = 1;
}
if (rootNode.hasNonNull("nextSignedPreKeyId")) {
nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
} else {
nextSignedPreKeyId = 1;
aciNextSignedPreKeyId = 1;
}
if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
} else {
pniPreKeyIdOffset = 1;
}
if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
} else {
pniNextSignedPreKeyId = 1;
}
if (rootNode.hasNonNull("profileKey")) {
try {
@ -618,7 +645,7 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy pre key store.");
for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
try {
getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) {
logger.warn("Failed to migrate pre key, ignoring", e);
}
@ -630,7 +657,8 @@ public class SignalAccount implements Closeable {
logger.debug("Migrating legacy signed pre key store.");
for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
try {
getSignedPreKeyStore().storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue()));
getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(),
new SignedPreKeyRecord(entry.getValue()));
} catch (InvalidMessageException e) {
logger.warn("Failed to migrate signed pre key, ignoring", e);
}
@ -813,8 +841,10 @@ public class SignalAccount implements Closeable {
.put("storageKey",
storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
.put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
.put("preKeyIdOffset", preKeyIdOffset)
.put("nextSignedPreKeyId", nextSignedPreKeyId)
.put("preKeyIdOffset", aciPreKeyIdOffset)
.put("nextSignedPreKeyId", aciNextSignedPreKeyId)
.put("pniPreKeyIdOffset", pniPreKeyIdOffset)
.put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
.put("profileKey",
profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
.put("registered", registered)
@ -852,25 +882,63 @@ public class SignalAccount implements Closeable {
return new Pair<>(fileChannel, lock);
}
public void addPreKeys(List<PreKeyRecord> records) {
public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciPreKeys(records);
} else {
addPniPreKeys(records);
}
}
public void addAciPreKeys(List<PreKeyRecord> records) {
for (var record : records) {
if (preKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyIdOffset);
if (aciPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getPreKeyStore().storePreKey(record.getId(), record);
preKeyIdOffset = (preKeyIdOffset + 1) % Medium.MAX_VALUE;
getAciPreKeyStore().storePreKey(record.getId(), record);
aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE;
}
save();
}
public void addSignedPreKey(SignedPreKeyRecord record) {
if (nextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), nextSignedPreKeyId);
public void addPniPreKeys(List<PreKeyRecord> records) {
for (var record : records) {
if (pniPreKeyIdOffset != record.getId()) {
logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
throw new AssertionError("Invalid pre key id");
}
getPniPreKeyStore().storePreKey(record.getId(), record);
pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE;
}
save();
}
public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) {
if (serviceIdType.equals(ServiceIdType.ACI)) {
addAciSignedPreKey(record);
} else {
addPniSignedPreKey(record);
}
}
public void addAciSignedPreKey(SignedPreKeyRecord record) {
if (aciNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id");
}
getSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
nextSignedPreKeyId = (nextSignedPreKeyId + 1) % Medium.MAX_VALUE;
getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
save();
}
public void addPniSignedPreKey(SignedPreKeyRecord record) {
if (pniNextSignedPreKeyId != record.getId()) {
logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId);
throw new AssertionError("Invalid signed pre key id");
}
getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
save();
}
@ -906,22 +974,32 @@ public class SignalAccount implements Closeable {
public SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
return getOrCreate(() -> signalProtocolStore,
() -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(),
getSignedPreKeyStore(),
() -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
getAciSignedPreKeyStore(),
getSessionStore(),
getIdentityKeyStore(),
getSenderKeyStore(),
this::isMultiDevice));
}
private PreKeyStore getPreKeyStore() {
return getOrCreate(() -> preKeyStore,
() -> preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, accountPath)));
private PreKeyStore getAciPreKeyStore() {
return getOrCreate(() -> aciPreKeyStore,
() -> aciPreKeyStore = new PreKeyStore(getAciPreKeysPath(dataPath, accountPath)));
}
private SignedPreKeyStore getSignedPreKeyStore() {
return getOrCreate(() -> signedPreKeyStore,
() -> signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, accountPath)));
private SignedPreKeyStore getAciSignedPreKeyStore() {
return getOrCreate(() -> aciSignedPreKeyStore,
() -> aciSignedPreKeyStore = new SignedPreKeyStore(getAciSignedPreKeysPath(dataPath, accountPath)));
}
private PreKeyStore getPniPreKeyStore() {
return getOrCreate(() -> pniPreKeyStore,
() -> pniPreKeyStore = new PreKeyStore(getPniPreKeysPath(dataPath, accountPath)));
}
private SignedPreKeyStore getPniSignedPreKeyStore() {
return getOrCreate(() -> pniSignedPreKeyStore,
() -> pniSignedPreKeyStore = new SignedPreKeyStore(getPniSignedPreKeysPath(dataPath, accountPath)));
}
public SessionStore getSessionStore() {
@ -1078,6 +1156,10 @@ public class SignalAccount implements Closeable {
return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
}
public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair;
}
public IdentityKeyPair getAciIdentityKeyPair() {
return aciIdentityKeyPair;
}
@ -1157,12 +1239,12 @@ public class SignalAccount implements Closeable {
return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey());
}
public int getPreKeyIdOffset() {
return preKeyIdOffset;
public int getPreKeyIdOffset(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset;
}
public int getNextSignedPreKeyId() {
return nextSignedPreKeyId;
public int getNextSignedPreKeyId(ServiceIdType serviceIdType) {
return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
}
public boolean isRegistered() {