Salsa: LRU needs keep dependency info

reivew #371 加深了我对 salsa 做 recomputation 的理解。

我们先来看一个测试,思考这个测试能否 pass

#[salsa::tracked(jar = Jar, lru = 3)]
fn get_hot_potato(db: &dyn Db, input: MyInput) -> Arc<HotPotato> {
    db.push_log(format!("get_hot_potato({:?})", input.field(db)));
    Arc::new(HotPotato::new(input.field(db)))
}

#[salsa::tracked(jar = Jar)]
fn get_hot_potato2(db: &dyn Db, input: MyInput) -> u32 {
    db.push_log(format!("get_hot_potato2({:?})", input.field(db)));
    get_hot_potato(db, input).0
}

#[test]
fn lru_keeps_dependency_info() {
    let mut db = Database::default();
    let capacity = 3;

    // Invoke `get_hot_potato2` 33 times. This will (in turn) invoke
    // `get_hot_potato`, which will trigger LRU after 32 executions.
    let inputs: Vec<MyInput> = (0..(capacity + 1))
        .map(|i| MyInput::new(&mut db, i as u32))
        .collect();

    for (i, input) in inputs.iter().enumerate() {
        let x = get_hot_potato2(&db, *input);
        assert_eq!(x as usize, i);
    }

    db.salsa_runtime_mut()
        .synthetic_write(salsa::Durability::HIGH);

    // We want to test that calls to `get_hot_potato2` are still considered
    // clean. Check that no new executions occur as we go here.
    db.assert_logs_len((capacity + 1) * 2);

    // calling `get_hot_potato2(0)` has to check that `get_hot_potato(0)` is still valid;
    // even though we've evicted it (LRU), we find that it is still good
    let p = get_hot_potato2(&db, *inputs.first().unwrap());
    assert_eq!(p, 0);
    db.assert_logs_len(0);
}

这个测试是要做什么?

为了方便,将 get_hot_potato 记为 getget_hot_potato2 记为 get2

我们有两个 tracked function:getget2,不用关心这两个 function 做了啥,只需要知道 get2 调用了 get,也就是说 get2 依赖 get 的结果。另一个需要注意的点是 get 设置了 lru = 3,而 get2 没有设置 lru。

现在调用 get2 四次,来看看 salsa 内部的状态,get 并没有存储 input 为 0 的结果,因为它设置了 lru = 3,input 为 0 的结果被 evict 了。

+-------+----------------+-----------------+
| input |       get      |       get2      |
+-------+----------------+-----------------+
|   0   |                |        0        |
+-------+----------------+-----------------+
|   1   |        1       |        1        |
+-------+----------------+-----------------+
|   2   |        2       |        2        |
+-------+----------------+-----------------+
|   3   |        3       |        3        |
+-------+----------------+-----------------+

我们现在调用 get2(0)(这里应该写 input0 会比较好,input0 = MyInput::new(&mut db, 0),但为了简单就写了 0), 能直接用 salsa 存储的结果吗?还需要重新计算吗?显然不用,不是存储有吗?那如果现在有别的输入改变了 salsa 的 current_revision 呢? synthetic_write 就是在做这件事。也就是说 salsa 的 current_revision 大于 getget2 结果被 verify 的 revision。这意味着 shallow_verify_memo 不能确定 get2(0) 的结果是否能用,得让 deep_verify_memo 去进一步判断。deep_verify_memo 会检查 get2(0) 依赖的其他计算结果有没有改变,所以会去查看 get(0), 发现压根儿没有存储 get(0) 的结果,当然也没法判断它有没有改变,只能保守地认为发生了改变,所以 get2(0)get(0) 都会重新计算。

显然,这不够好,实际上 get2(0) 的结果是可以重用的,因为 get(0) 的结果虽然被 evict 了,但并没有发生改变。 在我们这个场景下,get2(0) 并不关心 get(0) 的结果是多少,只关心从它上次被 verify 后有没有发生改变

有没有方法去优化它呢?#371 提供了一种解决方案。我们 evict 的时候,不再直接删除整个 Memo,只是把 Memo.value 设置为 None,保留 verified_atrevisions

/// Evicts the existing memo for the given key, replacing it
/// with an equivalent memo that has no value. If the memo is untracked, BaseInput,
/// or has values assigned as output of another query, this has no effect.
pub(super) fn evict(&self, key: K) {
    use dashmap::mapref::entry::Entry::*;
    use crate::runtime::local_state::QueryOrigin;

    if let Occupied(entry )=  self.map.entry(key) {
        let memo = entry.get().load();
        match memo.revisions.origin {
            QueryOrigin::Assigned(_)
            | QueryOrigin::DerivedUntracked(_)
            | QueryOrigin::BaseInput
            | QueryOrigin::Field => {
                // Careful: Cannot evict memos whose values were
                // assigned as output of another query
                // or those with untracked inputs
                // as their values cannot be reconstructed.
                return;
            },

            QueryOrigin::Derived(_) => {
                let memo_evicted = Arc::new(Memo::new(
                    None::<V>,
                    memo.verified_at.load(),
                    memo.revisions.clone(),
                ));

                entry.get().store(memo_evicted);
            }
        }
    }
}