aether_runtime/
lib.rs

1use aether_ast::{
2    merge_partition_cuts, merge_policy_envelopes, policy_allows, AggregateFunction, AggregateTerm,
3    DerivedTuple, DerivedTupleMetadata, ElementId, Literal, PartitionCut, PolicyContext,
4    PolicyEnvelope, PredicateId, QueryAst, QueryResult, QueryRow, RuleAst, RuleId, Term, Tuple,
5    TupleId, Value, Variable,
6};
7use aether_plan::CompiledProgram;
8use aether_resolver::ResolvedState;
9use indexmap::{IndexMap, IndexSet};
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use thiserror::Error;
13
14pub trait RuleRuntime {
15    fn evaluate(
16        &self,
17        state: &ResolvedState,
18        program: &CompiledProgram,
19    ) -> Result<DerivedSet, RuntimeError>;
20}
21
22#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
23pub struct RuntimeIteration {
24    pub iteration: usize,
25    pub delta_size: usize,
26}
27
28#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
29pub struct DerivedSet {
30    pub tuples: Vec<DerivedTuple>,
31    pub iterations: Vec<RuntimeIteration>,
32    pub predicate_index: IndexMap<PredicateId, Vec<TupleId>>,
33}
34
35impl DerivedSet {
36    pub fn has_converged(&self) -> bool {
37        match self.iterations.last() {
38            Some(iteration) => iteration.delta_size == 0,
39            None => true,
40        }
41    }
42}
43
44#[derive(Clone, Debug, Default)]
45struct RelationRow {
46    values: Vec<Value>,
47    tuple_id: Option<TupleId>,
48    source_datom_ids: Vec<ElementId>,
49    imported_cuts: Vec<PartitionCut>,
50    policy: Option<PolicyEnvelope>,
51}
52
53#[derive(Clone, Debug, Default)]
54struct MatchState {
55    bindings: IndexMap<Variable, Value>,
56    parent_tuple_ids: Vec<TupleId>,
57    source_datom_ids: Vec<ElementId>,
58    imported_cuts: Vec<PartitionCut>,
59    query_tuple_id: Option<TupleId>,
60    policy: Option<PolicyEnvelope>,
61}
62
63#[derive(Clone, Debug)]
64struct AggregatedMatch {
65    values: Vec<Value>,
66    parent_tuple_ids: Vec<TupleId>,
67    source_datom_ids: Vec<ElementId>,
68    imported_cuts: Vec<PartitionCut>,
69    policy: Option<PolicyEnvelope>,
70}
71
72#[derive(Clone, Debug)]
73struct AggregateGroup {
74    values: Vec<Option<Value>>,
75    accumulators: Vec<AggregateAccumulator>,
76    seen_bindings: IndexSet<String>,
77    parent_tuple_ids: Vec<TupleId>,
78    source_datom_ids: Vec<ElementId>,
79    imported_cuts: Vec<PartitionCut>,
80    policy: Option<PolicyEnvelope>,
81}
82
83#[derive(Clone, Debug)]
84enum AggregateAccumulator {
85    Count(u64),
86    SumI64(i64),
87    SumU64(u64),
88    SumF64(f64),
89    Min(Value),
90    Max(Value),
91}
92
/// Stateless [`RuleRuntime`] that evaluates programs one SCC at a time, feeding each round
/// with the tuples produced by the previous one.
#[derive(Clone, Copy, Debug, Default)]
pub struct SemiNaiveRuntime;
95
96impl RuleRuntime for SemiNaiveRuntime {
97    fn evaluate(
98        &self,
99        state: &ResolvedState,
100        program: &CompiledProgram,
101    ) -> Result<DerivedSet, RuntimeError> {
102        let extensional_rows = build_extensional_rows(state, program);
103        let intensional_predicates: IndexSet<PredicateId> = program
104            .rules
105            .iter()
106            .map(|rule| rule.head.predicate.id)
107            .collect();
108        let scc_lookup = build_scc_lookup(program);
109        let scc_order = build_scc_evaluation_order(program, &scc_lookup);
110        let rules_by_scc = build_rules_by_scc(program, &scc_lookup);
111
112        let mut derived_by_predicate: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
113        let mut tuple_keys = IndexSet::new();
114        let mut tuples = Vec::new();
115        let mut iterations = Vec::new();
116        let mut next_tuple_id = 1u64;
117        let mut iteration = 1usize;
118
119        for scc_id in scc_order {
120            let Some(rules) = rules_by_scc.get(&scc_id) else {
121                continue;
122            };
123            let current_scc_predicates: IndexSet<PredicateId> =
124                rules.iter().map(|rule| rule.head.predicate.id).collect();
125            let stratum = rules
126                .first()
127                .and_then(|rule| program.predicate_strata.get(&rule.head.predicate.id))
128                .copied()
129                .unwrap_or_default();
130
131            let mut delta_rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
132            loop {
133                let mut batch_rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
134                let mut batch_tuples = Vec::new();
135
136                for rule in rules {
137                    let aggregates = head_aggregates(rule);
138                    let anchor_indices = if aggregates.is_empty() {
139                        current_scc_positive_indices(rule, &current_scc_predicates)
140                    } else {
141                        Vec::new()
142                    };
143                    let anchor_plan = if delta_rows.is_empty() {
144                        if anchor_indices.is_empty() {
145                            vec![None]
146                        } else {
147                            Vec::new()
148                        }
149                    } else if anchor_indices.is_empty() {
150                        Vec::new()
151                    } else {
152                        anchor_indices.into_iter().map(Some).collect()
153                    };
154
155                    let mut aggregate_matches = Vec::new();
156
157                    for anchor_index in anchor_plan {
158                        let matches = evaluate_rule_body_variant(
159                            rule,
160                            anchor_index,
161                            &derived_by_predicate,
162                            &delta_rows,
163                            &extensional_rows,
164                            &intensional_predicates,
165                            &current_scc_predicates,
166                        )?;
167
168                        if !aggregates.is_empty() {
169                            aggregate_matches.extend(matches);
170                            continue;
171                        }
172
173                        for matched in matches {
174                            let values = materialize_non_aggregate_head(
175                                rule.id,
176                                &rule.head.terms,
177                                &matched.bindings,
178                            )?;
179                            let key = tuple_key(rule.head.predicate.id, &values);
180                            if tuple_keys.contains(&key) {
181                                continue;
182                            }
183
184                            let tuple_id = TupleId::new(next_tuple_id);
185                            next_tuple_id += 1;
186                            tuple_keys.insert(key);
187
188                            batch_rows.entry(rule.head.predicate.id).or_default().push(
189                                RelationRow {
190                                    values: values.clone(),
191                                    tuple_id: Some(tuple_id),
192                                    source_datom_ids: matched.source_datom_ids.clone(),
193                                    imported_cuts: matched.imported_cuts.clone(),
194                                    policy: matched.policy.clone(),
195                                },
196                            );
197                            batch_tuples.push(DerivedTuple {
198                                tuple: Tuple {
199                                    id: tuple_id,
200                                    predicate: rule.head.predicate.id,
201                                    values,
202                                },
203                                metadata: DerivedTupleMetadata {
204                                    rule_id: rule.id,
205                                    predicate_id: rule.head.predicate.id,
206                                    stratum,
207                                    scc_id,
208                                    iteration,
209                                    parent_tuple_ids: matched.parent_tuple_ids,
210                                    source_datom_ids: matched.source_datom_ids,
211                                    imported_cuts: matched.imported_cuts,
212                                },
213                                policy: matched.policy,
214                            });
215                        }
216                    }
217
218                    if !aggregates.is_empty() {
219                        let matches = materialize_aggregate_head(
220                            rule.id,
221                            &rule.head.terms,
222                            &aggregates,
223                            &aggregate_matches,
224                        )?;
225                        for matched in matches {
226                            let key = tuple_key(rule.head.predicate.id, &matched.values);
227                            if tuple_keys.contains(&key) {
228                                continue;
229                            }
230
231                            let tuple_id = TupleId::new(next_tuple_id);
232                            next_tuple_id += 1;
233                            tuple_keys.insert(key);
234
235                            batch_rows.entry(rule.head.predicate.id).or_default().push(
236                                RelationRow {
237                                    values: matched.values.clone(),
238                                    tuple_id: Some(tuple_id),
239                                    source_datom_ids: matched.source_datom_ids.clone(),
240                                    imported_cuts: matched.imported_cuts.clone(),
241                                    policy: matched.policy.clone(),
242                                },
243                            );
244                            batch_tuples.push(DerivedTuple {
245                                tuple: Tuple {
246                                    id: tuple_id,
247                                    predicate: rule.head.predicate.id,
248                                    values: matched.values,
249                                },
250                                metadata: DerivedTupleMetadata {
251                                    rule_id: rule.id,
252                                    predicate_id: rule.head.predicate.id,
253                                    stratum,
254                                    scc_id,
255                                    iteration,
256                                    parent_tuple_ids: matched.parent_tuple_ids,
257                                    source_datom_ids: matched.source_datom_ids,
258                                    imported_cuts: matched.imported_cuts,
259                                },
260                                policy: matched.policy,
261                            });
262                        }
263                    }
264                }
265
266                if batch_tuples.is_empty() {
267                    break;
268                }
269
270                iterations.push(RuntimeIteration {
271                    iteration,
272                    delta_size: batch_tuples.len(),
273                });
274                iteration += 1;
275
276                for (predicate, rows) in &batch_rows {
277                    derived_by_predicate
278                        .entry(*predicate)
279                        .or_default()
280                        .extend(rows.iter().cloned());
281                }
282                tuples.extend(batch_tuples);
283                delta_rows = batch_rows;
284            }
285        }
286
287        iterations.push(RuntimeIteration {
288            iteration,
289            delta_size: 0,
290        });
291
292        let mut predicate_index = program
293            .materialized
294            .iter()
295            .copied()
296            .map(|predicate| (predicate, Vec::new()))
297            .collect::<IndexMap<_, _>>();
298        for tuple in &tuples {
299            predicate_index
300                .entry(tuple.tuple.predicate)
301                .or_default()
302                .push(tuple.tuple.id);
303        }
304
305        Ok(DerivedSet {
306            tuples,
307            iterations,
308            predicate_index,
309        })
310    }
311}
312
313pub fn execute_query(
314    state: &ResolvedState,
315    program: &CompiledProgram,
316    derived: &DerivedSet,
317    query: &QueryAst,
318    policy_context: Option<&PolicyContext>,
319) -> Result<QueryResult, RuntimeError> {
320    let extensional_rows = build_extensional_rows(state, program);
321    let intensional_predicates: IndexSet<PredicateId> = program
322        .rules
323        .iter()
324        .map(|rule| rule.head.predicate.id)
325        .collect();
326    let derived_rows = build_derived_rows(derived);
327
328    let mut states = vec![MatchState::default()];
329    for goal in &query.goals {
330        let rows = positive_relation_rows(
331            goal.predicate.id,
332            None,
333            &derived_rows,
334            &IndexMap::new(),
335            &extensional_rows,
336            &intensional_predicates,
337            &IndexSet::new(),
338        )?;
339        let mut next_states = Vec::new();
340
341        for state in &states {
342            for row in &rows {
343                if !policy_allows(policy_context, row.policy.as_ref()) {
344                    continue;
345                }
346                if let Some(bindings) = unify_terms(&state.bindings, &goal.terms, &row.values) {
347                    next_states.push(MatchState {
348                        bindings,
349                        parent_tuple_ids: state.parent_tuple_ids.clone(),
350                        source_datom_ids: state.source_datom_ids.clone(),
351                        imported_cuts: merge_partition_cuts(
352                            state.imported_cuts.iter().chain(row.imported_cuts.iter()),
353                        ),
354                        query_tuple_id: row.tuple_id.or(state.query_tuple_id),
355                        policy: merge_policy_envelopes([
356                            state.policy.as_ref(),
357                            row.policy.as_ref(),
358                        ]),
359                    });
360                }
361            }
362        }
363
364        states = next_states;
365        if states.is_empty() {
366            break;
367        }
368    }
369
370    let mut rows = states
371        .into_iter()
372        .filter(|state| policy_allows(policy_context, state.policy.as_ref()))
373        .map(|state| QueryRow {
374            values: if query.keep.is_empty() {
375                state.bindings.values().cloned().collect()
376            } else {
377                query
378                    .keep
379                    .iter()
380                    .filter_map(|variable| state.bindings.get(variable).cloned())
381                    .collect()
382            },
383            tuple_id: state.query_tuple_id,
384        })
385        .collect::<Vec<_>>();
386    rows.sort_by_key(|row| {
387        let mut key = String::new();
388        for value in &row.values {
389            key.push_str(&value_key(value));
390            key.push('|');
391        }
392        key
393    });
394
395    Ok(QueryResult { rows })
396}
397
398fn build_extensional_rows(
399    state: &ResolvedState,
400    program: &CompiledProgram,
401) -> IndexMap<PredicateId, Vec<RelationRow>> {
402    let mut rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
403
404    for (predicate, attribute) in &program.extensional_bindings {
405        let mut predicate_rows = Vec::new();
406        for (entity_id, entity_state) in &state.entities {
407            predicate_rows.extend(entity_state.facts(attribute).iter().cloned().map(|fact| {
408                RelationRow {
409                    values: vec![Value::Entity(*entity_id), fact.value],
410                    tuple_id: None,
411                    source_datom_ids: fact.source_datom_ids,
412                    imported_cuts: Vec::new(),
413                    policy: fact.policy,
414                }
415            }));
416        }
417        rows.entry(*predicate).or_default().extend(predicate_rows);
418    }
419
420    for fact in &program.facts {
421        rows.entry(fact.predicate.id)
422            .or_default()
423            .push(RelationRow {
424                values: fact.values.clone(),
425                tuple_id: None,
426                source_datom_ids: fact
427                    .provenance
428                    .as_ref()
429                    .map(|provenance| provenance.source_datom_ids.clone())
430                    .unwrap_or_default(),
431                imported_cuts: fact
432                    .provenance
433                    .as_ref()
434                    .map(|provenance| provenance.imported_cuts.clone())
435                    .unwrap_or_default(),
436                policy: fact.policy.clone(),
437            });
438    }
439
440    rows
441}
442
443fn build_derived_rows(derived: &DerivedSet) -> IndexMap<PredicateId, Vec<RelationRow>> {
444    let mut rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
445    for tuple in &derived.tuples {
446        rows.entry(tuple.tuple.predicate)
447            .or_default()
448            .push(RelationRow {
449                values: tuple.tuple.values.clone(),
450                tuple_id: Some(tuple.tuple.id),
451                source_datom_ids: tuple.metadata.source_datom_ids.clone(),
452                imported_cuts: tuple.metadata.imported_cuts.clone(),
453                policy: tuple.policy.clone(),
454            });
455    }
456    rows
457}
458
459fn build_rules_by_scc<'a>(
460    program: &'a CompiledProgram,
461    scc_lookup: &IndexMap<PredicateId, usize>,
462) -> IndexMap<usize, Vec<&'a RuleAst>> {
463    let mut rules = IndexMap::new();
464    for rule in &program.rules {
465        let scc_id = *scc_lookup
466            .get(&rule.head.predicate.id)
467            .expect("rule head predicate should be present in scc lookup");
468        rules.entry(scc_id).or_insert_with(Vec::new).push(rule);
469    }
470    rules
471}
472
473fn build_scc_evaluation_order(
474    program: &CompiledProgram,
475    scc_lookup: &IndexMap<PredicateId, usize>,
476) -> Vec<usize> {
477    let mut edges = IndexSet::new();
478    let mut indegree = program
479        .sccs
480        .iter()
481        .map(|scc| (scc.id, 0usize))
482        .collect::<IndexMap<_, _>>();
483    let mut outgoing = program
484        .sccs
485        .iter()
486        .map(|scc| (scc.id, Vec::new()))
487        .collect::<IndexMap<_, _>>();
488
489    for (head, dependencies) in &program.dependency_graph.edges {
490        let head_scc = *scc_lookup
491            .get(head)
492            .expect("head predicate should be present in scc lookup");
493        for dependency in dependencies {
494            let dependency_scc = *scc_lookup
495                .get(dependency)
496                .expect("dependency predicate should be present in scc lookup");
497            if dependency_scc != head_scc && edges.insert((dependency_scc, head_scc)) {
498                outgoing.entry(dependency_scc).or_default().push(head_scc);
499                *indegree.entry(head_scc).or_default() += 1;
500            }
501        }
502    }
503
504    let scc_strata = program
505        .sccs
506        .iter()
507        .map(|scc| {
508            let stratum = scc
509                .predicates
510                .first()
511                .and_then(|predicate| program.predicate_strata.get(predicate))
512                .copied()
513                .unwrap_or_default();
514            (scc.id, stratum)
515        })
516        .collect::<IndexMap<_, _>>();
517    let mut ready = indegree
518        .iter()
519        .filter_map(|(scc_id, degree)| (*degree == 0).then_some(*scc_id))
520        .collect::<Vec<_>>();
521    ready.sort_by_key(|scc_id| (scc_strata.get(scc_id).copied().unwrap_or_default(), *scc_id));
522
523    let mut order = Vec::new();
524    while let Some(scc_id) = ready.first().copied() {
525        ready.remove(0);
526        order.push(scc_id);
527        if let Some(neighbors) = outgoing.get(&scc_id) {
528            for neighbor in neighbors {
529                let degree = indegree
530                    .get_mut(neighbor)
531                    .expect("neighbor scc should have indegree");
532                *degree -= 1;
533                if *degree == 0 {
534                    ready.push(*neighbor);
535                    ready.sort_by_key(|candidate| {
536                        (
537                            scc_strata.get(candidate).copied().unwrap_or_default(),
538                            *candidate,
539                        )
540                    });
541                }
542            }
543        }
544    }
545
546    order
547}
548
549fn current_scc_positive_indices(
550    rule: &RuleAst,
551    current_scc_predicates: &IndexSet<PredicateId>,
552) -> Vec<usize> {
553    rule.body
554        .iter()
555        .enumerate()
556        .filter_map(|(index, literal)| match literal {
557            Literal::Positive(atom) if current_scc_predicates.contains(&atom.predicate.id) => {
558                Some(index)
559            }
560            _ => None,
561        })
562        .collect()
563}
564
565fn evaluate_rule_body_variant(
566    rule: &RuleAst,
567    delta_anchor_index: Option<usize>,
568    derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
569    delta_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
570    extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
571    intensional_predicates: &IndexSet<PredicateId>,
572    current_scc_predicates: &IndexSet<PredicateId>,
573) -> Result<Vec<MatchState>, RuntimeError> {
574    let mut states = vec![MatchState::default()];
575
576    for (literal_index, literal) in ordered_rule_body(rule) {
577        match literal {
578            Literal::Positive(atom) => {
579                let rows = positive_relation_rows(
580                    atom.predicate.id,
581                    (delta_anchor_index == Some(literal_index)).then_some(()),
582                    derived_rows,
583                    delta_rows,
584                    extensional_rows,
585                    intensional_predicates,
586                    current_scc_predicates,
587                )?;
588                let mut next_states = Vec::new();
589
590                for state in &states {
591                    for row in &rows {
592                        if let Some(bindings) =
593                            unify_terms(&state.bindings, &atom.terms, &row.values)
594                        {
595                            let mut parent_tuple_ids = state.parent_tuple_ids.clone();
596                            if let Some(tuple_id) = row.tuple_id {
597                                if !parent_tuple_ids.contains(&tuple_id) {
598                                    parent_tuple_ids.push(tuple_id);
599                                }
600                            }
601                            let mut source_datom_ids = state.source_datom_ids.clone();
602                            extend_unique(&mut source_datom_ids, &row.source_datom_ids);
603                            next_states.push(MatchState {
604                                bindings,
605                                parent_tuple_ids,
606                                source_datom_ids,
607                                imported_cuts: merge_partition_cuts(
608                                    state.imported_cuts.iter().chain(row.imported_cuts.iter()),
609                                ),
610                                query_tuple_id: row.tuple_id.or(state.query_tuple_id),
611                                policy: merge_policy_envelopes([
612                                    state.policy.as_ref(),
613                                    row.policy.as_ref(),
614                                ]),
615                            });
616                        }
617                    }
618                }
619
620                states = next_states;
621            }
622            Literal::Negative(atom) => {
623                if current_scc_predicates.contains(&atom.predicate.id) {
624                    return Err(RuntimeError::UnsupportedIntraStratumNegation(rule.id));
625                }
626                let rows = negative_relation_rows(
627                    atom.predicate.id,
628                    derived_rows,
629                    extensional_rows,
630                    intensional_predicates,
631                )?;
632                states.retain(|state| {
633                    !rows
634                        .iter()
635                        .any(|row| unify_terms(&state.bindings, &atom.terms, &row.values).is_some())
636                });
637            }
638        }
639
640        if states.is_empty() {
641            break;
642        }
643    }
644
645    Ok(states)
646}
647
648fn ordered_rule_body(rule: &RuleAst) -> Vec<(usize, &Literal)> {
649    let mut positives = Vec::new();
650    let mut negatives = Vec::new();
651    for (index, literal) in rule.body.iter().enumerate() {
652        match literal {
653            Literal::Positive(_) => positives.push((index, literal)),
654            Literal::Negative(_) => negatives.push((index, literal)),
655        }
656    }
657    positives.extend(negatives);
658    positives
659}
660
661fn positive_relation_rows(
662    predicate: PredicateId,
663    use_delta: Option<()>,
664    derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
665    delta_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
666    extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
667    intensional_predicates: &IndexSet<PredicateId>,
668    _current_scc_predicates: &IndexSet<PredicateId>,
669) -> Result<Vec<RelationRow>, RuntimeError> {
670    if use_delta.is_some() {
671        return Ok(delta_rows.get(&predicate).cloned().unwrap_or_default());
672    }
673    if intensional_predicates.contains(&predicate) {
674        return Ok(derived_rows.get(&predicate).cloned().unwrap_or_default());
675    }
676
677    extensional_rows
678        .get(&predicate)
679        .cloned()
680        .ok_or(RuntimeError::MissingExtensionalBinding(predicate))
681}
682
683fn negative_relation_rows(
684    predicate: PredicateId,
685    derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
686    extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
687    intensional_predicates: &IndexSet<PredicateId>,
688) -> Result<Vec<RelationRow>, RuntimeError> {
689    if intensional_predicates.contains(&predicate) {
690        Ok(derived_rows.get(&predicate).cloned().unwrap_or_default())
691    } else {
692        extensional_rows
693            .get(&predicate)
694            .cloned()
695            .ok_or(RuntimeError::MissingExtensionalBinding(predicate))
696    }
697}
698
699fn unify_terms(
700    bindings: &IndexMap<Variable, Value>,
701    terms: &[Term],
702    values: &[Value],
703) -> Option<IndexMap<Variable, Value>> {
704    if terms.len() != values.len() {
705        return None;
706    }
707
708    let mut next_bindings = bindings.clone();
709    for (term, value) in terms.iter().zip(values) {
710        match term {
711            Term::Variable(variable) => match next_bindings.get(variable) {
712                Some(bound) if bound != value => return None,
713                Some(_) => {}
714                None => {
715                    next_bindings.insert(variable.clone(), value.clone());
716                }
717            },
718            Term::Value(expected) if expected != value => return None,
719            Term::Value(_) => {}
720            Term::Aggregate(_) => return None,
721        }
722    }
723
724    Some(next_bindings)
725}
726
727fn materialize_non_aggregate_head(
728    rule_id: RuleId,
729    terms: &[Term],
730    bindings: &IndexMap<Variable, Value>,
731) -> Result<Vec<Value>, RuntimeError> {
732    terms
733        .iter()
734        .map(|term| match term {
735            Term::Variable(variable) => {
736                bindings
737                    .get(variable)
738                    .cloned()
739                    .ok_or_else(|| RuntimeError::UnboundVariable {
740                        rule_id,
741                        variable: variable.0.clone(),
742                    })
743            }
744            Term::Value(value) => Ok(value.clone()),
745            Term::Aggregate(_) => Err(RuntimeError::UnexpectedAggregate(rule_id)),
746        })
747        .collect()
748}
749
750fn materialize_aggregate_head(
751    rule_id: RuleId,
752    terms: &[Term],
753    aggregates: &[(usize, &AggregateTerm)],
754    matches: &[MatchState],
755) -> Result<Vec<AggregatedMatch>, RuntimeError> {
756    let mut groups: IndexMap<String, AggregateGroup> = IndexMap::new();
757
758    for matched in matches {
759        let binding_key = bindings_key(&matched.bindings);
760        let group_values = materialize_group_values(rule_id, terms, aggregates, &matched.bindings)?;
761        let group_key = values_key(&group_values);
762
763        if !groups.contains_key(&group_key) {
764            let accumulators = aggregates
765                .iter()
766                .map(|(_, aggregate_term)| {
767                    let aggregate_value = matched
768                        .bindings
769                        .get(&aggregate_term.variable)
770                        .ok_or_else(|| RuntimeError::UnboundVariable {
771                            rule_id,
772                            variable: aggregate_term.variable.0.clone(),
773                        })?;
774                    AggregateAccumulator::from_value(
775                        rule_id,
776                        aggregate_term.function,
777                        aggregate_value,
778                    )
779                })
780                .collect::<Result<Vec<_>, _>>()?;
781            groups.insert(
782                group_key.clone(),
783                AggregateGroup {
784                    values: group_values.into_iter().map(Some).collect(),
785                    accumulators,
786                    seen_bindings: IndexSet::new(),
787                    parent_tuple_ids: Vec::new(),
788                    source_datom_ids: Vec::new(),
789                    imported_cuts: Vec::new(),
790                    policy: None,
791                },
792            );
793        }
794        let group = groups
795            .get_mut(&group_key)
796            .expect("aggregate group should exist after insertion");
797
798        if !group.seen_bindings.insert(binding_key) {
799            continue;
800        }
801
802        if group.seen_bindings.len() > 1 {
803            for (accumulator, (_, aggregate_term)) in
804                group.accumulators.iter_mut().zip(aggregates.iter())
805            {
806                let aggregate_value =
807                    matched
808                        .bindings
809                        .get(&aggregate_term.variable)
810                        .ok_or_else(|| RuntimeError::UnboundVariable {
811                            rule_id,
812                            variable: aggregate_term.variable.0.clone(),
813                        })?;
814                accumulator.add(rule_id, aggregate_term.function, aggregate_value)?;
815            }
816        }
817        extend_unique(&mut group.parent_tuple_ids, &matched.parent_tuple_ids);
818        extend_unique(&mut group.source_datom_ids, &matched.source_datom_ids);
819        group.imported_cuts = merge_partition_cuts(
820            group
821                .imported_cuts
822                .iter()
823                .chain(matched.imported_cuts.iter()),
824        );
825        group.policy = merge_policy_envelopes([group.policy.as_ref(), matched.policy.as_ref()]);
826    }
827
828    let mut aggregated = groups
829        .into_values()
830        .map(|mut group| {
831            let mut values = group
832                .values
833                .into_iter()
834                .map(|value| value.expect("group values are initialized"))
835                .collect::<Vec<_>>();
836            for ((aggregate_index, _), accumulator) in
837                aggregates.iter().zip(group.accumulators.drain(..))
838            {
839                values[*aggregate_index] = accumulator.finalize();
840            }
841            AggregatedMatch {
842                values,
843                parent_tuple_ids: group.parent_tuple_ids,
844                source_datom_ids: group.source_datom_ids,
845                imported_cuts: group.imported_cuts,
846                policy: group.policy,
847            }
848        })
849        .collect::<Vec<_>>();
850
851    aggregated.sort_by_key(|group| values_key(&group.values));
852    Ok(aggregated)
853}
854
855fn materialize_group_values(
856    rule_id: RuleId,
857    terms: &[Term],
858    aggregates: &[(usize, &AggregateTerm)],
859    bindings: &IndexMap<Variable, Value>,
860) -> Result<Vec<Value>, RuntimeError> {
861    terms
862        .iter()
863        .enumerate()
864        .map(|(index, term)| {
865            if aggregates
866                .iter()
867                .any(|(aggregate_index, _)| index == *aggregate_index)
868            {
869                return Ok(Value::Null);
870            }
871            match term {
872                Term::Variable(variable) => {
873                    bindings
874                        .get(variable)
875                        .cloned()
876                        .ok_or_else(|| RuntimeError::UnboundVariable {
877                            rule_id,
878                            variable: variable.0.clone(),
879                        })
880                }
881                Term::Value(value) => Ok(value.clone()),
882                Term::Aggregate(_) => Err(RuntimeError::UnexpectedAggregate(rule_id)),
883            }
884        })
885        .collect()
886}
887
888fn head_aggregates(rule: &RuleAst) -> Vec<(usize, &AggregateTerm)> {
889    rule.head
890        .terms
891        .iter()
892        .enumerate()
893        .filter_map(|(index, term)| match term {
894            Term::Aggregate(aggregate) => Some((index, aggregate)),
895            _ => None,
896        })
897        .collect()
898}
899
900fn bindings_key(bindings: &IndexMap<Variable, Value>) -> String {
901    let mut entries = bindings
902        .iter()
903        .map(|(variable, value)| (variable.0.as_str(), value_key(value)))
904        .collect::<Vec<_>>();
905    entries.sort_unstable_by(|left, right| left.0.cmp(right.0));
906
907    let mut rendered = String::new();
908    for (variable, value) in entries {
909        rendered.push_str(variable);
910        rendered.push('=');
911        rendered.push_str(&value);
912        rendered.push('|');
913    }
914    rendered
915}
916
917fn values_key(values: &[Value]) -> String {
918    let mut rendered = String::new();
919    for value in values {
920        rendered.push_str(&value_key(value));
921        rendered.push('|');
922    }
923    rendered
924}
925
926fn build_scc_lookup(program: &CompiledProgram) -> IndexMap<PredicateId, usize> {
927    let mut lookup = IndexMap::new();
928    for scc in &program.sccs {
929        for predicate in &scc.predicates {
930            lookup.insert(*predicate, scc.id);
931        }
932    }
933    lookup
934}
935
936impl AggregateAccumulator {
937    fn from_value(
938        rule_id: RuleId,
939        function: AggregateFunction,
940        value: &Value,
941    ) -> Result<Self, RuntimeError> {
942        match function {
943            AggregateFunction::Count => Ok(Self::Count(1)),
944            AggregateFunction::Sum => match value {
945                Value::I64(inner) => Ok(Self::SumI64(*inner)),
946                Value::U64(inner) => Ok(Self::SumU64(*inner)),
947                Value::F64(inner) => Ok(Self::SumF64(*inner)),
948                other => Err(RuntimeError::UnsupportedAggregateInput {
949                    rule_id,
950                    function,
951                    actual: runtime_value_type(other),
952                }),
953            },
954            AggregateFunction::Min => {
955                validate_orderable_input(rule_id, function, value).map(|_| Self::Min(value.clone()))
956            }
957            AggregateFunction::Max => {
958                validate_orderable_input(rule_id, function, value).map(|_| Self::Max(value.clone()))
959            }
960        }
961    }
962
963    fn add(
964        &mut self,
965        rule_id: RuleId,
966        function: AggregateFunction,
967        value: &Value,
968    ) -> Result<(), RuntimeError> {
969        match self {
970            Self::Count(count) => {
971                *count += 1;
972                Ok(())
973            }
974            Self::SumI64(total) => match value {
975                Value::I64(inner) => {
976                    *total += inner;
977                    Ok(())
978                }
979                other => Err(RuntimeError::AggregateInputTypeMismatch {
980                    rule_id,
981                    function,
982                    expected: "I64".into(),
983                    actual: runtime_value_type(other),
984                }),
985            },
986            Self::SumU64(total) => match value {
987                Value::U64(inner) => {
988                    *total += inner;
989                    Ok(())
990                }
991                other => Err(RuntimeError::AggregateInputTypeMismatch {
992                    rule_id,
993                    function,
994                    expected: "U64".into(),
995                    actual: runtime_value_type(other),
996                }),
997            },
998            Self::SumF64(total) => match value {
999                Value::F64(inner) => {
1000                    *total += inner;
1001                    Ok(())
1002                }
1003                other => Err(RuntimeError::AggregateInputTypeMismatch {
1004                    rule_id,
1005                    function,
1006                    expected: "F64".into(),
1007                    actual: runtime_value_type(other),
1008                }),
1009            },
1010            Self::Min(current) => {
1011                validate_orderable_input(rule_id, function, value)?;
1012                if compare_values(current, value)? == Ordering::Greater {
1013                    *current = value.clone();
1014                }
1015                Ok(())
1016            }
1017            Self::Max(current) => {
1018                validate_orderable_input(rule_id, function, value)?;
1019                if compare_values(current, value)? == Ordering::Less {
1020                    *current = value.clone();
1021                }
1022                Ok(())
1023            }
1024        }
1025    }
1026
1027    fn finalize(self) -> Value {
1028        match self {
1029            Self::Count(inner) => Value::U64(inner),
1030            Self::SumI64(inner) => Value::I64(inner),
1031            Self::SumU64(inner) => Value::U64(inner),
1032            Self::SumF64(inner) => Value::F64(inner),
1033            Self::Min(inner) | Self::Max(inner) => inner,
1034        }
1035    }
1036}
1037
1038fn validate_orderable_input(
1039    rule_id: RuleId,
1040    function: AggregateFunction,
1041    value: &Value,
1042) -> Result<(), RuntimeError> {
1043    match value {
1044        Value::I64(_) | Value::U64(_) | Value::F64(_) | Value::String(_) | Value::Entity(_) => {
1045            Ok(())
1046        }
1047        other => Err(RuntimeError::UnsupportedAggregateInput {
1048            rule_id,
1049            function,
1050            actual: runtime_value_type(other),
1051        }),
1052    }
1053}
1054
1055fn compare_values(left: &Value, right: &Value) -> Result<Ordering, RuntimeError> {
1056    match (left, right) {
1057        (Value::I64(left_inner), Value::I64(right_inner)) => Ok(left_inner.cmp(right_inner)),
1058        (Value::U64(left_inner), Value::U64(right_inner)) => Ok(left_inner.cmp(right_inner)),
1059        (Value::F64(left_inner), Value::F64(right_inner)) => left_inner
1060            .partial_cmp(right_inner)
1061            .ok_or_else(|| RuntimeError::NonComparableAggregateValues {
1062                left: runtime_value_type(left),
1063                right: runtime_value_type(right),
1064            }),
1065        (Value::String(left_inner), Value::String(right_inner)) => Ok(left_inner.cmp(right_inner)),
1066        (Value::Entity(left_inner), Value::Entity(right_inner)) => Ok(left_inner.cmp(right_inner)),
1067        _ => Err(RuntimeError::NonComparableAggregateValues {
1068            left: runtime_value_type(left),
1069            right: runtime_value_type(right),
1070        }),
1071    }
1072}
1073
1074fn tuple_key(predicate: PredicateId, values: &[Value]) -> String {
1075    let mut key = format!("{}#", predicate.0);
1076    for value in values {
1077        key.push_str(&value_key(value));
1078        key.push('|');
1079    }
1080    key
1081}
1082
/// Appends to `target` each element of `additions` that `target` does not
/// already contain, keeping first-seen order.
///
/// Duplicates inside `additions` are also collapsed, because membership is
/// re-checked against the growing `target` for every element.
fn extend_unique<T>(target: &mut Vec<T>, additions: &[T])
where
    T: Copy + Eq,
{
    for &candidate in additions {
        let already_present = target.iter().any(|existing| *existing == candidate);
        if !already_present {
            target.push(candidate);
        }
    }
}
1093
1094fn value_key(value: &Value) -> String {
1095    match value {
1096        Value::Null => "null".into(),
1097        Value::Bool(inner) => format!("bool:{inner}"),
1098        Value::I64(inner) => format!("i64:{inner}"),
1099        Value::U64(inner) => format!("u64:{inner}"),
1100        Value::F64(inner) => format!("f64:{:016x}", inner.to_bits()),
1101        Value::String(inner) => format!("string:{}:{inner}", inner.len()),
1102        Value::Bytes(inner) => format!("bytes:{inner:?}"),
1103        Value::Entity(inner) => format!("entity:{}", inner.0),
1104        Value::List(inner) => {
1105            let mut rendered = String::from("list:[");
1106            for value in inner {
1107                rendered.push_str(&value_key(value));
1108                rendered.push(',');
1109            }
1110            rendered.push(']');
1111            rendered
1112        }
1113    }
1114}
1115
1116fn runtime_value_type(value: &Value) -> String {
1117    match value {
1118        Value::Null => "Null".into(),
1119        Value::Bool(_) => "Bool".into(),
1120        Value::I64(_) => "I64".into(),
1121        Value::U64(_) => "U64".into(),
1122        Value::F64(_) => "F64".into(),
1123        Value::String(_) => "String".into(),
1124        Value::Bytes(_) => "Bytes".into(),
1125        Value::Entity(_) => "Entity".into(),
1126        Value::List(_) => "List".into(),
1127    }
1128}
1129
1130#[derive(Debug, Error)]
1131pub enum RuntimeError {
1132    #[error("predicate {0} has no extensional binding or fact rows in the compiled program")]
1133    MissingExtensionalBinding(PredicateId),
1134    #[error("rule {0} uses same-stratum negation, which is not supported")]
1135    UnsupportedIntraStratumNegation(RuleId),
1136    #[error("rule {rule_id} references unbound variable {variable}")]
1137    UnboundVariable { rule_id: RuleId, variable: String },
1138    #[error(
1139        "rule {0} requires grouped aggregate materialization, but was evaluated as a plain rule"
1140    )]
1141    UnexpectedAggregate(RuleId),
1142    #[error(
1143        "rule {rule_id} uses aggregate {function} over unsupported runtime value type {actual}"
1144    )]
1145    UnsupportedAggregateInput {
1146        rule_id: RuleId,
1147        function: AggregateFunction,
1148        actual: String,
1149    },
1150    #[error(
1151        "rule {rule_id} uses aggregate {function} with mixed runtime input types: expected {expected}, found {actual}"
1152    )]
1153    AggregateInputTypeMismatch {
1154        rule_id: RuleId,
1155        function: AggregateFunction,
1156        expected: String,
1157        actual: String,
1158    },
1159    #[error("aggregate comparison requires comparable values, found {left} and {right}")]
1160    NonComparableAggregateValues { left: String, right: String },
1161}
1162
1163#[cfg(test)]
1164mod tests {
1165    use super::{execute_query, RuleRuntime, RuntimeError, SemiNaiveRuntime};
1166    use aether_ast::{
1167        AggregateFunction, AggregateTerm, Atom, AttributeId, Datom, DatomProvenance, ElementId,
1168        EntityId, ExtensionalFact, Literal, PredicateId, PredicateRef, QueryAst, QueryRow, RuleAst,
1169        RuleId, RuleProgram, Term, Value, Variable,
1170    };
1171    use aether_resolver::{MaterializedResolver, Resolver};
1172    use aether_rules::{DefaultRuleCompiler, RuleCompiler};
1173    use aether_schema::{AttributeClass, AttributeSchema, PredicateSignature, Schema, ValueType};
1174
1175    fn predicate(id: u64, name: &str, arity: usize) -> PredicateRef {
1176        PredicateRef {
1177            id: PredicateId::new(id),
1178            name: name.into(),
1179            arity,
1180        }
1181    }
1182
1183    fn atom(predicate: PredicateRef, vars: &[&str]) -> Atom {
1184        Atom {
1185            predicate,
1186            terms: vars
1187                .iter()
1188                .map(|name| Term::Variable(Variable::new(*name)))
1189                .collect(),
1190        }
1191    }
1192
1193    fn aggregate(function: AggregateFunction, variable: &str) -> Term {
1194        Term::Aggregate(AggregateTerm {
1195            function,
1196            variable: Variable::new(variable),
1197        })
1198    }
1199
1200    fn dependency_datom(entity: u64, value: u64, element: u64) -> Datom {
1201        Datom {
1202            entity: EntityId::new(entity),
1203            attribute: AttributeId::new(1),
1204            value: Value::Entity(EntityId::new(value)),
1205            op: aether_ast::OperationKind::Add,
1206            element: ElementId::new(element),
1207            replica: aether_ast::ReplicaId::new(1),
1208            causal_context: Default::default(),
1209            provenance: DatomProvenance::default(),
1210            policy: None,
1211        }
1212    }
1213
1214    #[test]
1215    fn monotone_transitive_closure_converges_with_iteration_metadata() {
1216        let mut schema = Schema::new("v1");
1217        schema
1218            .register_attribute(AttributeSchema {
1219                id: AttributeId::new(1),
1220                name: "task.depends_on".into(),
1221                class: AttributeClass::RefSet,
1222                value_type: ValueType::Entity,
1223            })
1224            .expect("register attribute");
1225        schema
1226            .register_predicate(PredicateSignature {
1227                id: PredicateId::new(1),
1228                name: "task_depends_on".into(),
1229                fields: vec![ValueType::Entity, ValueType::Entity],
1230            })
1231            .expect("register extensional predicate");
1232        schema
1233            .register_predicate(PredicateSignature {
1234                id: PredicateId::new(2),
1235                name: "depends_transitive".into(),
1236                fields: vec![ValueType::Entity, ValueType::Entity],
1237            })
1238            .expect("register recursive predicate");
1239
1240        let program = RuleProgram {
1241            predicates: vec![
1242                predicate(1, "task_depends_on", 2),
1243                predicate(2, "depends_transitive", 2),
1244            ],
1245            rules: vec![
1246                RuleAst {
1247                    id: RuleId::new(1),
1248                    head: atom(predicate(2, "depends_transitive", 2), &["x", "y"]),
1249                    body: vec![Literal::Positive(atom(
1250                        predicate(1, "task_depends_on", 2),
1251                        &["x", "y"],
1252                    ))],
1253                },
1254                RuleAst {
1255                    id: RuleId::new(2),
1256                    head: atom(predicate(2, "depends_transitive", 2), &["x", "z"]),
1257                    body: vec![
1258                        Literal::Positive(atom(predicate(2, "depends_transitive", 2), &["x", "y"])),
1259                        Literal::Positive(atom(predicate(1, "task_depends_on", 2), &["y", "z"])),
1260                    ],
1261                },
1262            ],
1263            materialized: vec![PredicateId::new(2)],
1264            facts: Vec::new(),
1265        };
1266        let datoms = vec![
1267            dependency_datom(1, 2, 1),
1268            dependency_datom(2, 3, 2),
1269            dependency_datom(3, 4, 3),
1270        ];
1271        let state = MaterializedResolver
1272            .current(&schema, &datoms)
1273            .expect("resolve current state");
1274        let compiled = DefaultRuleCompiler
1275            .compile(&schema, &program)
1276            .expect("compile recursive program");
1277
1278        let derived = SemiNaiveRuntime
1279            .evaluate(&state, &compiled)
1280            .expect("evaluate recursive closure");
1281
1282        let mut pairs = derived
1283            .tuples
1284            .iter()
1285            .map(|tuple| {
1286                let [Value::Entity(left), Value::Entity(right)] = &tuple.tuple.values[..] else {
1287                    panic!("expected binary entity tuple");
1288                };
1289                (left.0, right.0)
1290            })
1291            .collect::<Vec<_>>();
1292        pairs.sort_unstable();
1293
1294        assert_eq!(pairs, vec![(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]);
1295        assert_eq!(
1296            derived
1297                .iterations
1298                .iter()
1299                .map(|iteration| iteration.delta_size)
1300                .collect::<Vec<_>>(),
1301            vec![3, 2, 1, 0]
1302        );
1303        let longest_path = derived
1304            .tuples
1305            .iter()
1306            .find(|tuple| {
1307                tuple.tuple.values
1308                    == vec![
1309                        Value::Entity(EntityId::new(1)),
1310                        Value::Entity(EntityId::new(4)),
1311                    ]
1312            })
1313            .expect("longest-path tuple");
1314        assert_eq!(
1315            longest_path.metadata.source_datom_ids,
1316            vec![ElementId::new(1), ElementId::new(2), ElementId::new(3)]
1317        );
1318    }
1319
1320    #[test]
1321    fn bounded_aggregation_materializes_counts_sums_and_maxima() {
1322        let mut schema = Schema::new("v1");
1323        for signature in [
1324            PredicateSignature {
1325                id: PredicateId::new(1),
1326                name: "edge".into(),
1327                fields: vec![ValueType::Entity, ValueType::Entity],
1328            },
1329            PredicateSignature {
1330                id: PredicateId::new(2),
1331                name: "reach".into(),
1332                fields: vec![ValueType::Entity, ValueType::Entity],
1333            },
1334            PredicateSignature {
1335                id: PredicateId::new(3),
1336                name: "reachable_count".into(),
1337                fields: vec![ValueType::Entity, ValueType::U64],
1338            },
1339            PredicateSignature {
1340                id: PredicateId::new(4),
1341                name: "project_task".into(),
1342                fields: vec![ValueType::Entity, ValueType::Entity],
1343            },
1344            PredicateSignature {
1345                id: PredicateId::new(5),
1346                name: "task_hours".into(),
1347                fields: vec![ValueType::Entity, ValueType::U64],
1348            },
1349            PredicateSignature {
1350                id: PredicateId::new(6),
1351                name: "project_hours".into(),
1352                fields: vec![ValueType::Entity, ValueType::U64],
1353            },
1354            PredicateSignature {
1355                id: PredicateId::new(9),
1356                name: "project_stats".into(),
1357                fields: vec![ValueType::Entity, ValueType::U64, ValueType::U64],
1358            },
1359            PredicateSignature {
1360                id: PredicateId::new(7),
1361                name: "execution_attempt".into(),
1362                fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1363            },
1364            PredicateSignature {
1365                id: PredicateId::new(8),
1366                name: "latest_epoch".into(),
1367                fields: vec![ValueType::Entity, ValueType::U64],
1368            },
1369        ] {
1370            schema
1371                .register_predicate(signature)
1372                .expect("register predicate");
1373        }
1374
1375        let program = RuleProgram {
1376            predicates: vec![
1377                predicate(1, "edge", 2),
1378                predicate(2, "reach", 2),
1379                predicate(3, "reachable_count", 2),
1380                predicate(4, "project_task", 2),
1381                predicate(5, "task_hours", 2),
1382                predicate(6, "project_hours", 2),
1383                predicate(9, "project_stats", 3),
1384                predicate(7, "execution_attempt", 3),
1385                predicate(8, "latest_epoch", 2),
1386            ],
1387            rules: vec![
1388                RuleAst {
1389                    id: RuleId::new(1),
1390                    head: atom(predicate(2, "reach", 2), &["x", "y"]),
1391                    body: vec![Literal::Positive(atom(
1392                        predicate(1, "edge", 2),
1393                        &["x", "y"],
1394                    ))],
1395                },
1396                RuleAst {
1397                    id: RuleId::new(2),
1398                    head: atom(predicate(2, "reach", 2), &["x", "z"]),
1399                    body: vec![
1400                        Literal::Positive(atom(predicate(2, "reach", 2), &["x", "y"])),
1401                        Literal::Positive(atom(predicate(1, "edge", 2), &["y", "z"])),
1402                    ],
1403                },
1404                RuleAst {
1405                    id: RuleId::new(3),
1406                    head: Atom {
1407                        predicate: predicate(3, "reachable_count", 2),
1408                        terms: vec![
1409                            Term::Variable(Variable::new("x")),
1410                            aggregate(AggregateFunction::Count, "y"),
1411                        ],
1412                    },
1413                    body: vec![Literal::Positive(atom(
1414                        predicate(2, "reach", 2),
1415                        &["x", "y"],
1416                    ))],
1417                },
1418                RuleAst {
1419                    id: RuleId::new(4),
1420                    head: Atom {
1421                        predicate: predicate(6, "project_hours", 2),
1422                        terms: vec![
1423                            Term::Variable(Variable::new("project")),
1424                            aggregate(AggregateFunction::Sum, "hours"),
1425                        ],
1426                    },
1427                    body: vec![
1428                        Literal::Positive(atom(
1429                            predicate(4, "project_task", 2),
1430                            &["project", "task"],
1431                        )),
1432                        Literal::Positive(atom(predicate(5, "task_hours", 2), &["task", "hours"])),
1433                    ],
1434                },
1435                RuleAst {
1436                    id: RuleId::new(5),
1437                    head: Atom {
1438                        predicate: predicate(9, "project_stats", 3),
1439                        terms: vec![
1440                            Term::Variable(Variable::new("project")),
1441                            aggregate(AggregateFunction::Count, "task"),
1442                            aggregate(AggregateFunction::Sum, "hours"),
1443                        ],
1444                    },
1445                    body: vec![
1446                        Literal::Positive(atom(
1447                            predicate(4, "project_task", 2),
1448                            &["project", "task"],
1449                        )),
1450                        Literal::Positive(atom(predicate(5, "task_hours", 2), &["task", "hours"])),
1451                    ],
1452                },
1453                RuleAst {
1454                    id: RuleId::new(6),
1455                    head: Atom {
1456                        predicate: predicate(8, "latest_epoch", 2),
1457                        terms: vec![
1458                            Term::Variable(Variable::new("task")),
1459                            aggregate(AggregateFunction::Max, "epoch"),
1460                        ],
1461                    },
1462                    body: vec![Literal::Positive(atom(
1463                        predicate(7, "execution_attempt", 3),
1464                        &["task", "worker", "epoch"],
1465                    ))],
1466                },
1467            ],
1468            materialized: vec![
1469                PredicateId::new(2),
1470                PredicateId::new(3),
1471                PredicateId::new(6),
1472                PredicateId::new(9),
1473                PredicateId::new(8),
1474            ],
1475            facts: vec![
1476                ExtensionalFact {
1477                    predicate: predicate(1, "edge", 2),
1478                    values: vec![
1479                        Value::Entity(EntityId::new(1)),
1480                        Value::Entity(EntityId::new(2)),
1481                    ],
1482                    policy: None,
1483                    provenance: None,
1484                },
1485                ExtensionalFact {
1486                    predicate: predicate(1, "edge", 2),
1487                    values: vec![
1488                        Value::Entity(EntityId::new(2)),
1489                        Value::Entity(EntityId::new(3)),
1490                    ],
1491                    policy: None,
1492                    provenance: None,
1493                },
1494                ExtensionalFact {
1495                    predicate: predicate(1, "edge", 2),
1496                    values: vec![
1497                        Value::Entity(EntityId::new(3)),
1498                        Value::Entity(EntityId::new(4)),
1499                    ],
1500                    policy: None,
1501                    provenance: None,
1502                },
1503                ExtensionalFact {
1504                    predicate: predicate(4, "project_task", 2),
1505                    values: vec![
1506                        Value::Entity(EntityId::new(10)),
1507                        Value::Entity(EntityId::new(101)),
1508                    ],
1509                    policy: None,
1510                    provenance: None,
1511                },
1512                ExtensionalFact {
1513                    predicate: predicate(4, "project_task", 2),
1514                    values: vec![
1515                        Value::Entity(EntityId::new(10)),
1516                        Value::Entity(EntityId::new(102)),
1517                    ],
1518                    policy: None,
1519                    provenance: None,
1520                },
1521                ExtensionalFact {
1522                    predicate: predicate(5, "task_hours", 2),
1523                    values: vec![Value::Entity(EntityId::new(101)), Value::U64(3)],
1524                    policy: None,
1525                    provenance: None,
1526                },
1527                ExtensionalFact {
1528                    predicate: predicate(5, "task_hours", 2),
1529                    values: vec![Value::Entity(EntityId::new(102)), Value::U64(5)],
1530                    policy: None,
1531                    provenance: None,
1532                },
1533                ExtensionalFact {
1534                    predicate: predicate(7, "execution_attempt", 3),
1535                    values: vec![
1536                        Value::Entity(EntityId::new(1)),
1537                        Value::String("worker-a".into()),
1538                        Value::U64(1),
1539                    ],
1540                    policy: None,
1541                    provenance: None,
1542                },
1543                ExtensionalFact {
1544                    predicate: predicate(7, "execution_attempt", 3),
1545                    values: vec![
1546                        Value::Entity(EntityId::new(1)),
1547                        Value::String("worker-b".into()),
1548                        Value::U64(4),
1549                    ],
1550                    policy: None,
1551                    provenance: None,
1552                },
1553            ],
1554        };
1555
1556        let compiled = DefaultRuleCompiler
1557            .compile(&schema, &program)
1558            .expect("compile aggregate program");
1559        let derived = SemiNaiveRuntime
1560            .evaluate(&Default::default(), &compiled)
1561            .expect("evaluate aggregate program");
1562
1563        let reachable_count = execute_query(
1564            &Default::default(),
1565            &compiled,
1566            &derived,
1567            &QueryAst {
1568                goals: vec![atom(predicate(3, "reachable_count", 2), &["x", "count"])],
1569                keep: vec![Variable::new("x"), Variable::new("count")],
1570            },
1571            None,
1572        )
1573        .expect("query reachable count");
1574        assert_eq!(
1575            reachable_count.rows,
1576            vec![
1577                QueryRow {
1578                    values: vec![Value::Entity(EntityId::new(1)), Value::U64(3)],
1579                    tuple_id: reachable_count.rows[0].tuple_id,
1580                },
1581                QueryRow {
1582                    values: vec![Value::Entity(EntityId::new(2)), Value::U64(2)],
1583                    tuple_id: reachable_count.rows[1].tuple_id,
1584                },
1585                QueryRow {
1586                    values: vec![Value::Entity(EntityId::new(3)), Value::U64(1)],
1587                    tuple_id: reachable_count.rows[2].tuple_id,
1588                },
1589            ]
1590        );
1591
1592        let project_hours = execute_query(
1593            &Default::default(),
1594            &compiled,
1595            &derived,
1596            &QueryAst {
1597                goals: vec![atom(
1598                    predicate(6, "project_hours", 2),
1599                    &["project", "hours"],
1600                )],
1601                keep: vec![Variable::new("project"), Variable::new("hours")],
1602            },
1603            None,
1604        )
1605        .expect("query project hours");
1606        assert_eq!(
1607            project_hours.rows[0].values,
1608            vec![Value::Entity(EntityId::new(10)), Value::U64(8)]
1609        );
1610
1611        let project_stats = execute_query(
1612            &Default::default(),
1613            &compiled,
1614            &derived,
1615            &QueryAst {
1616                goals: vec![atom(
1617                    predicate(9, "project_stats", 3),
1618                    &["project", "task_count", "hours"],
1619                )],
1620                keep: vec![
1621                    Variable::new("project"),
1622                    Variable::new("task_count"),
1623                    Variable::new("hours"),
1624                ],
1625            },
1626            None,
1627        )
1628        .expect("query project stats");
1629        assert_eq!(
1630            project_stats.rows[0].values,
1631            vec![
1632                Value::Entity(EntityId::new(10)),
1633                Value::U64(2),
1634                Value::U64(8),
1635            ]
1636        );
1637
1638        let latest_epoch = execute_query(
1639            &Default::default(),
1640            &compiled,
1641            &derived,
1642            &QueryAst {
1643                goals: vec![atom(predicate(8, "latest_epoch", 2), &["task", "epoch"])],
1644                keep: vec![Variable::new("task"), Variable::new("epoch")],
1645            },
1646            None,
1647        )
1648        .expect("query latest epoch");
1649        assert_eq!(
1650            latest_epoch.rows[0].values,
1651            vec![Value::Entity(EntityId::new(1)), Value::U64(4)]
1652        );
1653    }
1654
1655    #[test]
1656    fn stratified_negation_supports_readiness_and_stale_rejection() {
1657        let mut schema = Schema::new("v1");
1658        for attribute in [
1659            AttributeSchema {
1660                id: AttributeId::new(1),
1661                name: "task.depends_on".into(),
1662                class: AttributeClass::RefSet,
1663                value_type: ValueType::Entity,
1664            },
1665            AttributeSchema {
1666                id: AttributeId::new(2),
1667                name: "task.status".into(),
1668                class: AttributeClass::ScalarLww,
1669                value_type: ValueType::String,
1670            },
1671            AttributeSchema {
1672                id: AttributeId::new(3),
1673                name: "task.claimed_by".into(),
1674                class: AttributeClass::ScalarLww,
1675                value_type: ValueType::String,
1676            },
1677            AttributeSchema {
1678                id: AttributeId::new(4),
1679                name: "task.lease_epoch".into(),
1680                class: AttributeClass::ScalarLww,
1681                value_type: ValueType::U64,
1682            },
1683            AttributeSchema {
1684                id: AttributeId::new(5),
1685                name: "task.lease_state".into(),
1686                class: AttributeClass::ScalarLww,
1687                value_type: ValueType::String,
1688            },
1689        ] {
1690            schema
1691                .register_attribute(attribute)
1692                .expect("register attribute");
1693        }
1694
1695        for signature in [
1696            PredicateSignature {
1697                id: PredicateId::new(1),
1698                name: "task".into(),
1699                fields: vec![ValueType::Entity],
1700            },
1701            PredicateSignature {
1702                id: PredicateId::new(2),
1703                name: "execution_attempt".into(),
1704                fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1705            },
1706            PredicateSignature {
1707                id: PredicateId::new(3),
1708                name: "task_depends_on".into(),
1709                fields: vec![ValueType::Entity, ValueType::Entity],
1710            },
1711            PredicateSignature {
1712                id: PredicateId::new(4),
1713                name: "task_status".into(),
1714                fields: vec![ValueType::Entity, ValueType::String],
1715            },
1716            PredicateSignature {
1717                id: PredicateId::new(5),
1718                name: "task_claimed_by".into(),
1719                fields: vec![ValueType::Entity, ValueType::String],
1720            },
1721            PredicateSignature {
1722                id: PredicateId::new(6),
1723                name: "task_lease_epoch".into(),
1724                fields: vec![ValueType::Entity, ValueType::U64],
1725            },
1726            PredicateSignature {
1727                id: PredicateId::new(7),
1728                name: "task_lease_state".into(),
1729                fields: vec![ValueType::Entity, ValueType::String],
1730            },
1731            PredicateSignature {
1732                id: PredicateId::new(8),
1733                name: "task_complete".into(),
1734                fields: vec![ValueType::Entity],
1735            },
1736            PredicateSignature {
1737                id: PredicateId::new(9),
1738                name: "dependency_blocked".into(),
1739                fields: vec![ValueType::Entity],
1740            },
1741            PredicateSignature {
1742                id: PredicateId::new(10),
1743                name: "lease_active".into(),
1744                fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1745            },
1746            PredicateSignature {
1747                id: PredicateId::new(11),
1748                name: "active_claim".into(),
1749                fields: vec![ValueType::Entity],
1750            },
1751            PredicateSignature {
1752                id: PredicateId::new(12),
1753                name: "task_ready".into(),
1754                fields: vec![ValueType::Entity],
1755            },
1756            PredicateSignature {
1757                id: PredicateId::new(13),
1758                name: "execution_rejected_stale".into(),
1759                fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1760            },
1761        ] {
1762            schema
1763                .register_predicate(signature)
1764                .expect("register predicate");
1765        }
1766
1767        let program = RuleProgram {
1768            predicates: vec![
1769                predicate(1, "task", 1),
1770                predicate(2, "execution_attempt", 3),
1771                predicate(3, "task_depends_on", 2),
1772                predicate(4, "task_status", 2),
1773                predicate(5, "task_claimed_by", 2),
1774                predicate(6, "task_lease_epoch", 2),
1775                predicate(7, "task_lease_state", 2),
1776                predicate(8, "task_complete", 1),
1777                predicate(9, "dependency_blocked", 1),
1778                predicate(10, "lease_active", 3),
1779                predicate(11, "active_claim", 1),
1780                predicate(12, "task_ready", 1),
1781                predicate(13, "execution_rejected_stale", 3),
1782            ],
1783            rules: vec![
1784                RuleAst {
1785                    id: RuleId::new(1),
1786                    head: atom(predicate(8, "task_complete", 1), &["t"]),
1787                    body: vec![Literal::Positive(Atom {
1788                        predicate: predicate(4, "task_status", 2),
1789                        terms: vec![
1790                            Term::Variable(Variable::new("t")),
1791                            Term::Value(Value::String("done".into())),
1792                        ],
1793                    })],
1794                },
1795                RuleAst {
1796                    id: RuleId::new(2),
1797                    head: atom(predicate(9, "dependency_blocked", 1), &["t"]),
1798                    body: vec![
1799                        Literal::Positive(atom(predicate(3, "task_depends_on", 2), &["t", "dep"])),
1800                        Literal::Negative(atom(predicate(8, "task_complete", 1), &["dep"])),
1801                    ],
1802                },
1803                RuleAst {
1804                    id: RuleId::new(3),
1805                    head: atom(predicate(10, "lease_active", 3), &["t", "worker", "epoch"]),
1806                    body: vec![
1807                        Literal::Positive(atom(
1808                            predicate(5, "task_claimed_by", 2),
1809                            &["t", "worker"],
1810                        )),
1811                        Literal::Positive(atom(
1812                            predicate(6, "task_lease_epoch", 2),
1813                            &["t", "epoch"],
1814                        )),
1815                        Literal::Positive(Atom {
1816                            predicate: predicate(7, "task_lease_state", 2),
1817                            terms: vec![
1818                                Term::Variable(Variable::new("t")),
1819                                Term::Value(Value::String("active".into())),
1820                            ],
1821                        }),
1822                    ],
1823                },
1824                RuleAst {
1825                    id: RuleId::new(4),
1826                    head: atom(predicate(11, "active_claim", 1), &["t"]),
1827                    body: vec![Literal::Positive(atom(
1828                        predicate(10, "lease_active", 3),
1829                        &["t", "worker", "epoch"],
1830                    ))],
1831                },
1832                RuleAst {
1833                    id: RuleId::new(5),
1834                    head: atom(predicate(12, "task_ready", 1), &["t"]),
1835                    body: vec![
1836                        Literal::Positive(atom(predicate(1, "task", 1), &["t"])),
1837                        Literal::Negative(atom(predicate(9, "dependency_blocked", 1), &["t"])),
1838                        Literal::Negative(atom(predicate(11, "active_claim", 1), &["t"])),
1839                    ],
1840                },
1841                RuleAst {
1842                    id: RuleId::new(6),
1843                    head: atom(
1844                        predicate(13, "execution_rejected_stale", 3),
1845                        &["t", "worker", "epoch"],
1846                    ),
1847                    body: vec![
1848                        Literal::Positive(atom(
1849                            predicate(2, "execution_attempt", 3),
1850                            &["t", "worker", "epoch"],
1851                        )),
1852                        Literal::Negative(atom(
1853                            predicate(10, "lease_active", 3),
1854                            &["t", "worker", "epoch"],
1855                        )),
1856                    ],
1857                },
1858            ],
1859            materialized: vec![PredicateId::new(12), PredicateId::new(13)],
1860            facts: vec![
1861                ExtensionalFact {
1862                    predicate: predicate(1, "task", 1),
1863                    values: vec![Value::Entity(EntityId::new(1))],
1864                    policy: None,
1865                    provenance: None,
1866                },
1867                ExtensionalFact {
1868                    predicate: predicate(1, "task", 1),
1869                    values: vec![Value::Entity(EntityId::new(2))],
1870                    policy: None,
1871                    provenance: None,
1872                },
1873                ExtensionalFact {
1874                    predicate: predicate(2, "execution_attempt", 3),
1875                    values: vec![
1876                        Value::Entity(EntityId::new(1)),
1877                        Value::String("worker-a".into()),
1878                        Value::U64(1),
1879                    ],
1880                    policy: None,
1881                    provenance: None,
1882                },
1883            ],
1884        };
1885        let datoms = vec![
1886            dependency_datom(1, 2, 1),
1887            datom(2, 2, Value::String("done".into()), 2),
1888            datom(1, 3, Value::String("worker-a".into()), 3),
1889            datom(1, 4, Value::U64(1), 4),
1890            datom(1, 5, Value::String("active".into()), 5),
1891            datom(1, 5, Value::String("expired".into()), 6),
1892        ];
1893
1894        let compiled = DefaultRuleCompiler
1895            .compile(&schema, &program)
1896            .expect("compile coordination program");
1897        let as_of_state = MaterializedResolver
1898            .as_of(&schema, &datoms, &ElementId::new(5))
1899            .expect("resolve as_of");
1900        let current_state = MaterializedResolver
1901            .current(&schema, &datoms)
1902            .expect("resolve current");
1903        let as_of_derived = SemiNaiveRuntime
1904            .evaluate(&as_of_state, &compiled)
1905            .expect("evaluate as_of");
1906        let current_derived = SemiNaiveRuntime
1907            .evaluate(&current_state, &compiled)
1908            .expect("evaluate current");
1909
1910        let as_of_ready = execute_query(
1911            &as_of_state,
1912            &compiled,
1913            &as_of_derived,
1914            &QueryAst {
1915                goals: vec![
1916                    atom(predicate(12, "task_ready", 1), &["t"]),
1917                    Atom {
1918                        predicate: predicate(5, "task_claimed_by", 2),
1919                        terms: vec![
1920                            Term::Variable(Variable::new("t")),
1921                            Term::Value(Value::String("worker-a".into())),
1922                        ],
1923                    },
1924                ],
1925                keep: vec![Variable::new("t")],
1926            },
1927            None,
1928        )
1929        .expect("query as_of ready");
1930        assert!(as_of_ready.rows.is_empty());
1931
1932        let current_ready = execute_query(
1933            &current_state,
1934            &compiled,
1935            &current_derived,
1936            &QueryAst {
1937                goals: vec![
1938                    atom(predicate(12, "task_ready", 1), &["t"]),
1939                    Atom {
1940                        predicate: predicate(5, "task_claimed_by", 2),
1941                        terms: vec![
1942                            Term::Variable(Variable::new("t")),
1943                            Term::Value(Value::String("worker-a".into())),
1944                        ],
1945                    },
1946                ],
1947                keep: vec![Variable::new("t")],
1948            },
1949            None,
1950        )
1951        .expect("query current ready");
1952        assert_eq!(current_ready.rows.len(), 1);
1953        assert_eq!(
1954            current_ready.rows[0].values,
1955            vec![Value::Entity(EntityId::new(1))]
1956        );
1957
1958        let stale_attempts = execute_query(
1959            &current_state,
1960            &compiled,
1961            &current_derived,
1962            &QueryAst {
1963                goals: vec![atom(
1964                    predicate(13, "execution_rejected_stale", 3),
1965                    &["t", "worker", "epoch"],
1966                )],
1967                keep: vec![
1968                    Variable::new("t"),
1969                    Variable::new("worker"),
1970                    Variable::new("epoch"),
1971                ],
1972            },
1973            None,
1974        )
1975        .expect("query stale attempts");
1976        assert_eq!(
1977            stale_attempts.rows,
1978            vec![QueryRow {
1979                values: vec![
1980                    Value::Entity(EntityId::new(1)),
1981                    Value::String("worker-a".into()),
1982                    Value::U64(1),
1983                ],
1984                tuple_id: stale_attempts.rows.first().and_then(|row| row.tuple_id),
1985            }]
1986        );
1987    }
1988
1989    #[test]
1990    fn missing_extensional_binding_is_reported() {
1991        let mut schema = Schema::new("v1");
1992        schema
1993            .register_predicate(PredicateSignature {
1994                id: PredicateId::new(10),
1995                name: "edge".into(),
1996                fields: vec![ValueType::Entity, ValueType::Entity],
1997            })
1998            .expect("register edge");
1999        schema
2000            .register_predicate(PredicateSignature {
2001                id: PredicateId::new(11),
2002                name: "reach".into(),
2003                fields: vec![ValueType::Entity, ValueType::Entity],
2004            })
2005            .expect("register reach");
2006        let program = RuleProgram {
2007            predicates: vec![predicate(10, "edge", 2), predicate(11, "reach", 2)],
2008            rules: vec![RuleAst {
2009                id: RuleId::new(1),
2010                head: atom(predicate(11, "reach", 2), &["x", "y"]),
2011                body: vec![Literal::Positive(atom(
2012                    predicate(10, "edge", 2),
2013                    &["x", "y"],
2014                ))],
2015            }],
2016            materialized: vec![PredicateId::new(11)],
2017            facts: Vec::new(),
2018        };
2019        let compiled = DefaultRuleCompiler
2020            .compile(&schema, &program)
2021            .expect("compile unbound program");
2022
2023        let error = SemiNaiveRuntime
2024            .evaluate(&Default::default(), &compiled)
2025            .expect_err("missing extensional binding should fail");
2026        assert!(matches!(
2027            error,
2028            RuntimeError::MissingExtensionalBinding(id) if id == PredicateId::new(10)
2029        ));
2030    }
2031
2032    fn datom(entity: u64, attribute: u64, value: Value, element: u64) -> Datom {
2033        Datom {
2034            entity: EntityId::new(entity),
2035            attribute: AttributeId::new(attribute),
2036            value,
2037            op: aether_ast::OperationKind::Assert,
2038            element: ElementId::new(element),
2039            replica: aether_ast::ReplicaId::new(1),
2040            causal_context: Default::default(),
2041            provenance: DatomProvenance::default(),
2042            policy: None,
2043        }
2044    }
2045}