Restructure pre key refresh to be more robust

This commit is contained in:
AsamK 2024-02-18 16:32:50 +01:00
parent 91ed49e019
commit 25258db55d

View file

@ -51,6 +51,14 @@ public class PreKeyHelper {
return; return;
} }
if (refreshPreKeysIfNecessary(serviceIdType, identityKeyPair)) {
refreshPreKeysIfNecessary(serviceIdType, identityKeyPair);
}
}
private boolean refreshPreKeysIfNecessary(
final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) throws IOException {
OneTimePreKeyCounts preKeyCounts; OneTimePreKeyCounts preKeyCounts;
try { try {
preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType); preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
@ -59,12 +67,7 @@ public class PreKeyHelper {
preKeyCounts = new OneTimePreKeyCounts(0, 0); preKeyCounts = new OneTimePreKeyCounts(0, 0);
} }
SignedPreKeyRecord signedPreKeyRecord = null;
List<PreKeyRecord> preKeyRecords = null; List<PreKeyRecord> preKeyRecords = null;
KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
List<KyberPreKeyRecord> kyberPreKeyRecords = null;
try {
if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain", logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
serviceIdType, serviceIdType,
@ -72,18 +75,14 @@ public class PreKeyHelper {
ServiceConfig.PREKEY_MINIMUM_COUNT); ServiceConfig.PREKEY_MINIMUM_COUNT);
preKeyRecords = generatePreKeys(serviceIdType); preKeyRecords = generatePreKeys(serviceIdType);
} }
SignedPreKeyRecord signedPreKeyRecord = null;
if (signedPreKeyNeedsRefresh(serviceIdType)) { if (signedPreKeyNeedsRefresh(serviceIdType)) {
logger.debug("Refreshing {} signed pre key.", serviceIdType); logger.debug("Refreshing {} signed pre key.", serviceIdType);
signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
} }
} catch (Exception e) {
logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
account.resetPreKeyOffsets(serviceIdType);
preKeyRecords = generatePreKeys(serviceIdType);
signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
}
try { List<KyberPreKeyRecord> kyberPreKeyRecords = null;
if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain", logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
serviceIdType, serviceIdType,
@ -91,33 +90,57 @@ public class PreKeyHelper {
ServiceConfig.PREKEY_MINIMUM_COUNT); ServiceConfig.PREKEY_MINIMUM_COUNT);
kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair); kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
} }
KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) { if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType); logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
} }
} catch (Exception e) {
logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e); if (signedPreKeyRecord == null
account.resetKyberPreKeyOffsets(serviceIdType); && preKeyRecords == null
kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair); && lastResortKyberPreKeyRecord == null
lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); && kyberPreKeyRecords == null) {
return false;
} }
if (signedPreKeyRecord != null
|| preKeyRecords != null
|| lastResortKyberPreKeyRecord != null
|| kyberPreKeyRecords != null) {
final var preKeyUpload = new PreKeyUpload(serviceIdType, final var preKeyUpload = new PreKeyUpload(serviceIdType,
signedPreKeyRecord, signedPreKeyRecord,
preKeyRecords, preKeyRecords,
lastResortKyberPreKeyRecord, lastResortKyberPreKeyRecord,
kyberPreKeyRecords); kyberPreKeyRecords);
var needsReset = false;
try { try {
dependencies.getAccountManager().setPreKeys(preKeyUpload); dependencies.getAccountManager().setPreKeys(preKeyUpload);
try {
if (preKeyRecords != null) {
account.addPreKeys(serviceIdType, preKeyRecords);
}
if (signedPreKeyRecord != null) {
account.addSignedPreKey(serviceIdType, signedPreKeyRecord);
}
} catch (Exception e) {
logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
account.resetPreKeyOffsets(serviceIdType);
needsReset = true;
}
try {
if (kyberPreKeyRecords != null) {
account.addKyberPreKeys(serviceIdType, kyberPreKeyRecords);
}
if (lastResortKyberPreKeyRecord != null) {
account.addLastResortKyberPreKey(serviceIdType, lastResortKyberPreKeyRecord);
}
} catch (Exception e) {
logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e);
account.resetKyberPreKeyOffsets(serviceIdType);
needsReset = true;
}
} catch (AuthorizationFailedException e) { } catch (AuthorizationFailedException e) {
// This can happen when the primary device has changed phone number // This can happen when the primary device has changed phone number
logger.warn("Failed to updated pre keys: {}", e.getMessage()); logger.warn("Failed to updated pre keys: {}", e.getMessage());
} }
} return needsReset;
} }
public void cleanOldPreKeys() { public void cleanOldPreKeys() {
@ -135,7 +158,6 @@ public class PreKeyHelper {
final var offset = accountData.getPreKeyMetadata().getNextPreKeyId(); final var offset = accountData.getPreKeyMetadata().getNextPreKeyId();
var records = KeyUtils.generatePreKeyRecords(offset); var records = KeyUtils.generatePreKeyRecords(offset);
account.addPreKeys(serviceIdType, records);
return records; return records;
} }
@ -159,10 +181,7 @@ public class PreKeyHelper {
final var accountData = account.getAccountData(serviceIdType); final var accountData = account.getAccountData(serviceIdType);
final var signedPreKeyId = accountData.getPreKeyMetadata().getNextSignedPreKeyId(); final var signedPreKeyId = accountData.getPreKeyMetadata().getNextSignedPreKeyId();
var record = KeyUtils.generateSignedPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey()); return KeyUtils.generateSignedPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
account.addSignedPreKey(serviceIdType, record);
return record;
} }
private List<KyberPreKeyRecord> generateKyberPreKeys( private List<KyberPreKeyRecord> generateKyberPreKeys(
@ -171,10 +190,7 @@ public class PreKeyHelper {
final var accountData = account.getAccountData(serviceIdType); final var accountData = account.getAccountData(serviceIdType);
final var offset = accountData.getPreKeyMetadata().getNextKyberPreKeyId(); final var offset = accountData.getPreKeyMetadata().getNextKyberPreKeyId();
var records = KeyUtils.generateKyberPreKeyRecords(offset, identityKeyPair.getPrivateKey()); return KeyUtils.generateKyberPreKeyRecords(offset, identityKeyPair.getPrivateKey());
account.addKyberPreKeys(serviceIdType, records);
return records;
} }
private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) { private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
@ -199,10 +215,7 @@ public class PreKeyHelper {
final var accountData = account.getAccountData(serviceIdType); final var accountData = account.getAccountData(serviceIdType);
final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId(); final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId();
var record = KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey()); return KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
account.addLastResortKyberPreKey(serviceIdType, record);
return record;
} }
private void cleanSignedPreKeys(ServiceIdType serviceIdType) { private void cleanSignedPreKeys(ServiceIdType serviceIdType) {