1use aether_ast::{
2 merge_partition_cuts, merge_policy_envelopes, policy_allows, AggregateFunction, AggregateTerm,
3 DerivedTuple, DerivedTupleMetadata, ElementId, Literal, PartitionCut, PolicyContext,
4 PolicyEnvelope, PredicateId, QueryAst, QueryResult, QueryRow, RuleAst, RuleId, Term, Tuple,
5 TupleId, Value, Variable,
6};
7use aether_plan::CompiledProgram;
8use aether_resolver::ResolvedState;
9use indexmap::{IndexMap, IndexSet};
10use serde::{Deserialize, Serialize};
11use std::cmp::Ordering;
12use thiserror::Error;
13
/// Strategy interface for evaluating a compiled rule program against resolved state.
pub trait RuleRuntime {
    /// Evaluates `program` over `state` and returns the derived tuples together with
    /// per-iteration bookkeeping.
    ///
    /// # Errors
    ///
    /// Returns a [`RuntimeError`] when the implementation cannot evaluate the program.
    fn evaluate(
        &self,
        state: &ResolvedState,
        program: &CompiledProgram,
    ) -> Result<DerivedSet, RuntimeError>;
}
21
/// Bookkeeping record for a single evaluation round.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct RuntimeIteration {
    /// Round number. `SemiNaiveRuntime` starts at 1 and keeps counting across SCCs.
    pub iteration: usize,
    /// Number of new tuples produced in this round; `0` marks the closing record.
    pub delta_size: usize,
}
27
/// Result of a rule evaluation: the derived tuples plus indexes over them.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct DerivedSet {
    /// Every derived tuple, in the order it was produced.
    pub tuples: Vec<DerivedTuple>,
    /// One record per evaluation round, in order.
    pub iterations: Vec<RuntimeIteration>,
    /// Tuple ids grouped by the predicate of the tuple they identify.
    pub predicate_index: IndexMap<PredicateId, Vec<TupleId>>,
}
34
35impl DerivedSet {
36 pub fn has_converged(&self) -> bool {
37 match self.iterations.last() {
38 Some(iteration) => iteration.delta_size == 0,
39 None => true,
40 }
41 }
42}
43
/// One row of a relation, either extensional (from state or program facts) or derived.
#[derive(Clone, Debug, Default)]
struct RelationRow {
    /// Column values of the row.
    values: Vec<Value>,
    /// Id of the derived tuple backing this row; `None` for extensional rows.
    tuple_id: Option<TupleId>,
    /// Datom elements this row was sourced from.
    source_datom_ids: Vec<ElementId>,
    /// Partition cuts carried along with the row.
    imported_cuts: Vec<PartitionCut>,
    /// Policy envelope attached to the row, if any.
    policy: Option<PolicyEnvelope>,
}
52
/// A partial (or complete) match of a rule body or query, built up literal by literal.
#[derive(Clone, Debug, Default)]
struct MatchState {
    /// Variable bindings established so far.
    bindings: IndexMap<Variable, Value>,
    /// Ids of the derived tuples that were matched to reach this state.
    parent_tuple_ids: Vec<TupleId>,
    /// Datom elements accumulated from the matched rows.
    source_datom_ids: Vec<ElementId>,
    /// Merged partition cuts of the matched rows.
    imported_cuts: Vec<PartitionCut>,
    /// Tuple id of the most recently matched row that had one.
    query_tuple_id: Option<TupleId>,
    /// Merged policy envelope of the matched rows.
    policy: Option<PolicyEnvelope>,
}
62
/// A fully materialized head row produced by grouped aggregation, with merged provenance.
#[derive(Clone, Debug)]
struct AggregatedMatch {
    /// Final head values, with aggregate positions filled in.
    values: Vec<Value>,
    /// Union of the parent tuple ids of every match in the group.
    parent_tuple_ids: Vec<TupleId>,
    /// Union of the source datom ids of every match in the group.
    source_datom_ids: Vec<ElementId>,
    /// Merged partition cuts of every match in the group.
    imported_cuts: Vec<PartitionCut>,
    /// Merged policy envelope of every match in the group.
    policy: Option<PolicyEnvelope>,
}
71
/// Working state for one aggregation group while body matches are folded into it.
#[derive(Clone, Debug)]
struct AggregateGroup {
    /// Head values of the group. Every slot is created as `Some`; aggregate positions
    /// hold a placeholder until the accumulators are finalized.
    values: Vec<Option<Value>>,
    /// One accumulator per aggregate term of the head, in head order.
    accumulators: Vec<AggregateAccumulator>,
    /// Rendered binding keys already folded in, used to ignore repeated bindings.
    seen_bindings: IndexSet<String>,
    /// Union of the parent tuple ids of the folded matches.
    parent_tuple_ids: Vec<TupleId>,
    /// Union of the source datom ids of the folded matches.
    source_datom_ids: Vec<ElementId>,
    /// Merged partition cuts of the folded matches.
    imported_cuts: Vec<PartitionCut>,
    /// Merged policy envelope of the folded matches.
    policy: Option<PolicyEnvelope>,
}
82
/// Running state of a single aggregate function over one group.
#[derive(Clone, Debug)]
enum AggregateAccumulator {
    /// Number of distinct bindings folded in so far.
    Count(u64),
    /// Running sum over `Value::I64` inputs.
    SumI64(i64),
    /// Running sum over `Value::U64` inputs.
    SumU64(u64),
    /// Running sum over `Value::F64` inputs.
    SumF64(f64),
    /// Smallest value seen so far.
    Min(Value),
    /// Largest value seen so far.
    Max(Value),
}
92
/// Rule runtime that evaluates one SCC at a time, re-running recursive rules only
/// against the rows derived in the previous round.
#[derive(Default)]
pub struct SemiNaiveRuntime;
95
96impl RuleRuntime for SemiNaiveRuntime {
97 fn evaluate(
98 &self,
99 state: &ResolvedState,
100 program: &CompiledProgram,
101 ) -> Result<DerivedSet, RuntimeError> {
102 let extensional_rows = build_extensional_rows(state, program);
103 let intensional_predicates: IndexSet<PredicateId> = program
104 .rules
105 .iter()
106 .map(|rule| rule.head.predicate.id)
107 .collect();
108 let scc_lookup = build_scc_lookup(program);
109 let scc_order = build_scc_evaluation_order(program, &scc_lookup);
110 let rules_by_scc = build_rules_by_scc(program, &scc_lookup);
111
112 let mut derived_by_predicate: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
113 let mut tuple_keys = IndexSet::new();
114 let mut tuples = Vec::new();
115 let mut iterations = Vec::new();
116 let mut next_tuple_id = 1u64;
117 let mut iteration = 1usize;
118
119 for scc_id in scc_order {
120 let Some(rules) = rules_by_scc.get(&scc_id) else {
121 continue;
122 };
123 let current_scc_predicates: IndexSet<PredicateId> =
124 rules.iter().map(|rule| rule.head.predicate.id).collect();
125 let stratum = rules
126 .first()
127 .and_then(|rule| program.predicate_strata.get(&rule.head.predicate.id))
128 .copied()
129 .unwrap_or_default();
130
131 let mut delta_rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
132 loop {
133 let mut batch_rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
134 let mut batch_tuples = Vec::new();
135
136 for rule in rules {
137 let aggregates = head_aggregates(rule);
138 let anchor_indices = if aggregates.is_empty() {
139 current_scc_positive_indices(rule, ¤t_scc_predicates)
140 } else {
141 Vec::new()
142 };
143 let anchor_plan = if delta_rows.is_empty() {
144 if anchor_indices.is_empty() {
145 vec![None]
146 } else {
147 Vec::new()
148 }
149 } else if anchor_indices.is_empty() {
150 Vec::new()
151 } else {
152 anchor_indices.into_iter().map(Some).collect()
153 };
154
155 let mut aggregate_matches = Vec::new();
156
157 for anchor_index in anchor_plan {
158 let matches = evaluate_rule_body_variant(
159 rule,
160 anchor_index,
161 &derived_by_predicate,
162 &delta_rows,
163 &extensional_rows,
164 &intensional_predicates,
165 ¤t_scc_predicates,
166 )?;
167
168 if !aggregates.is_empty() {
169 aggregate_matches.extend(matches);
170 continue;
171 }
172
173 for matched in matches {
174 let values = materialize_non_aggregate_head(
175 rule.id,
176 &rule.head.terms,
177 &matched.bindings,
178 )?;
179 let key = tuple_key(rule.head.predicate.id, &values);
180 if tuple_keys.contains(&key) {
181 continue;
182 }
183
184 let tuple_id = TupleId::new(next_tuple_id);
185 next_tuple_id += 1;
186 tuple_keys.insert(key);
187
188 batch_rows.entry(rule.head.predicate.id).or_default().push(
189 RelationRow {
190 values: values.clone(),
191 tuple_id: Some(tuple_id),
192 source_datom_ids: matched.source_datom_ids.clone(),
193 imported_cuts: matched.imported_cuts.clone(),
194 policy: matched.policy.clone(),
195 },
196 );
197 batch_tuples.push(DerivedTuple {
198 tuple: Tuple {
199 id: tuple_id,
200 predicate: rule.head.predicate.id,
201 values,
202 },
203 metadata: DerivedTupleMetadata {
204 rule_id: rule.id,
205 predicate_id: rule.head.predicate.id,
206 stratum,
207 scc_id,
208 iteration,
209 parent_tuple_ids: matched.parent_tuple_ids,
210 source_datom_ids: matched.source_datom_ids,
211 imported_cuts: matched.imported_cuts,
212 },
213 policy: matched.policy,
214 });
215 }
216 }
217
218 if !aggregates.is_empty() {
219 let matches = materialize_aggregate_head(
220 rule.id,
221 &rule.head.terms,
222 &aggregates,
223 &aggregate_matches,
224 )?;
225 for matched in matches {
226 let key = tuple_key(rule.head.predicate.id, &matched.values);
227 if tuple_keys.contains(&key) {
228 continue;
229 }
230
231 let tuple_id = TupleId::new(next_tuple_id);
232 next_tuple_id += 1;
233 tuple_keys.insert(key);
234
235 batch_rows.entry(rule.head.predicate.id).or_default().push(
236 RelationRow {
237 values: matched.values.clone(),
238 tuple_id: Some(tuple_id),
239 source_datom_ids: matched.source_datom_ids.clone(),
240 imported_cuts: matched.imported_cuts.clone(),
241 policy: matched.policy.clone(),
242 },
243 );
244 batch_tuples.push(DerivedTuple {
245 tuple: Tuple {
246 id: tuple_id,
247 predicate: rule.head.predicate.id,
248 values: matched.values,
249 },
250 metadata: DerivedTupleMetadata {
251 rule_id: rule.id,
252 predicate_id: rule.head.predicate.id,
253 stratum,
254 scc_id,
255 iteration,
256 parent_tuple_ids: matched.parent_tuple_ids,
257 source_datom_ids: matched.source_datom_ids,
258 imported_cuts: matched.imported_cuts,
259 },
260 policy: matched.policy,
261 });
262 }
263 }
264 }
265
266 if batch_tuples.is_empty() {
267 break;
268 }
269
270 iterations.push(RuntimeIteration {
271 iteration,
272 delta_size: batch_tuples.len(),
273 });
274 iteration += 1;
275
276 for (predicate, rows) in &batch_rows {
277 derived_by_predicate
278 .entry(*predicate)
279 .or_default()
280 .extend(rows.iter().cloned());
281 }
282 tuples.extend(batch_tuples);
283 delta_rows = batch_rows;
284 }
285 }
286
287 iterations.push(RuntimeIteration {
288 iteration,
289 delta_size: 0,
290 });
291
292 let mut predicate_index = program
293 .materialized
294 .iter()
295 .copied()
296 .map(|predicate| (predicate, Vec::new()))
297 .collect::<IndexMap<_, _>>();
298 for tuple in &tuples {
299 predicate_index
300 .entry(tuple.tuple.predicate)
301 .or_default()
302 .push(tuple.tuple.id);
303 }
304
305 Ok(DerivedSet {
306 tuples,
307 iterations,
308 predicate_index,
309 })
310 }
311}
312
313pub fn execute_query(
314 state: &ResolvedState,
315 program: &CompiledProgram,
316 derived: &DerivedSet,
317 query: &QueryAst,
318 policy_context: Option<&PolicyContext>,
319) -> Result<QueryResult, RuntimeError> {
320 let extensional_rows = build_extensional_rows(state, program);
321 let intensional_predicates: IndexSet<PredicateId> = program
322 .rules
323 .iter()
324 .map(|rule| rule.head.predicate.id)
325 .collect();
326 let derived_rows = build_derived_rows(derived);
327
328 let mut states = vec![MatchState::default()];
329 for goal in &query.goals {
330 let rows = positive_relation_rows(
331 goal.predicate.id,
332 None,
333 &derived_rows,
334 &IndexMap::new(),
335 &extensional_rows,
336 &intensional_predicates,
337 &IndexSet::new(),
338 )?;
339 let mut next_states = Vec::new();
340
341 for state in &states {
342 for row in &rows {
343 if !policy_allows(policy_context, row.policy.as_ref()) {
344 continue;
345 }
346 if let Some(bindings) = unify_terms(&state.bindings, &goal.terms, &row.values) {
347 next_states.push(MatchState {
348 bindings,
349 parent_tuple_ids: state.parent_tuple_ids.clone(),
350 source_datom_ids: state.source_datom_ids.clone(),
351 imported_cuts: merge_partition_cuts(
352 state.imported_cuts.iter().chain(row.imported_cuts.iter()),
353 ),
354 query_tuple_id: row.tuple_id.or(state.query_tuple_id),
355 policy: merge_policy_envelopes([
356 state.policy.as_ref(),
357 row.policy.as_ref(),
358 ]),
359 });
360 }
361 }
362 }
363
364 states = next_states;
365 if states.is_empty() {
366 break;
367 }
368 }
369
370 let mut rows = states
371 .into_iter()
372 .filter(|state| policy_allows(policy_context, state.policy.as_ref()))
373 .map(|state| QueryRow {
374 values: if query.keep.is_empty() {
375 state.bindings.values().cloned().collect()
376 } else {
377 query
378 .keep
379 .iter()
380 .filter_map(|variable| state.bindings.get(variable).cloned())
381 .collect()
382 },
383 tuple_id: state.query_tuple_id,
384 })
385 .collect::<Vec<_>>();
386 rows.sort_by_key(|row| {
387 let mut key = String::new();
388 for value in &row.values {
389 key.push_str(&value_key(value));
390 key.push('|');
391 }
392 key
393 });
394
395 Ok(QueryResult { rows })
396}
397
398fn build_extensional_rows(
399 state: &ResolvedState,
400 program: &CompiledProgram,
401) -> IndexMap<PredicateId, Vec<RelationRow>> {
402 let mut rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
403
404 for (predicate, attribute) in &program.extensional_bindings {
405 let mut predicate_rows = Vec::new();
406 for (entity_id, entity_state) in &state.entities {
407 predicate_rows.extend(entity_state.facts(attribute).iter().cloned().map(|fact| {
408 RelationRow {
409 values: vec![Value::Entity(*entity_id), fact.value],
410 tuple_id: None,
411 source_datom_ids: fact.source_datom_ids,
412 imported_cuts: Vec::new(),
413 policy: fact.policy,
414 }
415 }));
416 }
417 rows.entry(*predicate).or_default().extend(predicate_rows);
418 }
419
420 for fact in &program.facts {
421 rows.entry(fact.predicate.id)
422 .or_default()
423 .push(RelationRow {
424 values: fact.values.clone(),
425 tuple_id: None,
426 source_datom_ids: fact
427 .provenance
428 .as_ref()
429 .map(|provenance| provenance.source_datom_ids.clone())
430 .unwrap_or_default(),
431 imported_cuts: fact
432 .provenance
433 .as_ref()
434 .map(|provenance| provenance.imported_cuts.clone())
435 .unwrap_or_default(),
436 policy: fact.policy.clone(),
437 });
438 }
439
440 rows
441}
442
443fn build_derived_rows(derived: &DerivedSet) -> IndexMap<PredicateId, Vec<RelationRow>> {
444 let mut rows: IndexMap<PredicateId, Vec<RelationRow>> = IndexMap::new();
445 for tuple in &derived.tuples {
446 rows.entry(tuple.tuple.predicate)
447 .or_default()
448 .push(RelationRow {
449 values: tuple.tuple.values.clone(),
450 tuple_id: Some(tuple.tuple.id),
451 source_datom_ids: tuple.metadata.source_datom_ids.clone(),
452 imported_cuts: tuple.metadata.imported_cuts.clone(),
453 policy: tuple.policy.clone(),
454 });
455 }
456 rows
457}
458
459fn build_rules_by_scc<'a>(
460 program: &'a CompiledProgram,
461 scc_lookup: &IndexMap<PredicateId, usize>,
462) -> IndexMap<usize, Vec<&'a RuleAst>> {
463 let mut rules = IndexMap::new();
464 for rule in &program.rules {
465 let scc_id = *scc_lookup
466 .get(&rule.head.predicate.id)
467 .expect("rule head predicate should be present in scc lookup");
468 rules.entry(scc_id).or_insert_with(Vec::new).push(rule);
469 }
470 rules
471}
472
473fn build_scc_evaluation_order(
474 program: &CompiledProgram,
475 scc_lookup: &IndexMap<PredicateId, usize>,
476) -> Vec<usize> {
477 let mut edges = IndexSet::new();
478 let mut indegree = program
479 .sccs
480 .iter()
481 .map(|scc| (scc.id, 0usize))
482 .collect::<IndexMap<_, _>>();
483 let mut outgoing = program
484 .sccs
485 .iter()
486 .map(|scc| (scc.id, Vec::new()))
487 .collect::<IndexMap<_, _>>();
488
489 for (head, dependencies) in &program.dependency_graph.edges {
490 let head_scc = *scc_lookup
491 .get(head)
492 .expect("head predicate should be present in scc lookup");
493 for dependency in dependencies {
494 let dependency_scc = *scc_lookup
495 .get(dependency)
496 .expect("dependency predicate should be present in scc lookup");
497 if dependency_scc != head_scc && edges.insert((dependency_scc, head_scc)) {
498 outgoing.entry(dependency_scc).or_default().push(head_scc);
499 *indegree.entry(head_scc).or_default() += 1;
500 }
501 }
502 }
503
504 let scc_strata = program
505 .sccs
506 .iter()
507 .map(|scc| {
508 let stratum = scc
509 .predicates
510 .first()
511 .and_then(|predicate| program.predicate_strata.get(predicate))
512 .copied()
513 .unwrap_or_default();
514 (scc.id, stratum)
515 })
516 .collect::<IndexMap<_, _>>();
517 let mut ready = indegree
518 .iter()
519 .filter_map(|(scc_id, degree)| (*degree == 0).then_some(*scc_id))
520 .collect::<Vec<_>>();
521 ready.sort_by_key(|scc_id| (scc_strata.get(scc_id).copied().unwrap_or_default(), *scc_id));
522
523 let mut order = Vec::new();
524 while let Some(scc_id) = ready.first().copied() {
525 ready.remove(0);
526 order.push(scc_id);
527 if let Some(neighbors) = outgoing.get(&scc_id) {
528 for neighbor in neighbors {
529 let degree = indegree
530 .get_mut(neighbor)
531 .expect("neighbor scc should have indegree");
532 *degree -= 1;
533 if *degree == 0 {
534 ready.push(*neighbor);
535 ready.sort_by_key(|candidate| {
536 (
537 scc_strata.get(candidate).copied().unwrap_or_default(),
538 *candidate,
539 )
540 });
541 }
542 }
543 }
544 }
545
546 order
547}
548
549fn current_scc_positive_indices(
550 rule: &RuleAst,
551 current_scc_predicates: &IndexSet<PredicateId>,
552) -> Vec<usize> {
553 rule.body
554 .iter()
555 .enumerate()
556 .filter_map(|(index, literal)| match literal {
557 Literal::Positive(atom) if current_scc_predicates.contains(&atom.predicate.id) => {
558 Some(index)
559 }
560 _ => None,
561 })
562 .collect()
563}
564
/// Computes every match of `rule`'s body for one evaluation variant.
///
/// Positive literals are joined first, then negative literals filter the result (see
/// `ordered_rule_body`). When `delta_anchor_index` names a body position, the positive
/// literal at that position reads only from `delta_rows`; every other literal reads
/// the full derived or extensional relation.
///
/// # Errors
///
/// * [`RuntimeError::UnsupportedIntraStratumNegation`] when a negative literal names a
///   predicate of the current SCC.
/// * [`RuntimeError::MissingExtensionalBinding`] when a literal names a predicate that
///   is neither a rule head nor backed by extensional rows.
///
/// Because the join stops as soon as no partial match is left, a later literal that
/// would raise one of these errors may never be inspected.
fn evaluate_rule_body_variant(
    rule: &RuleAst,
    delta_anchor_index: Option<usize>,
    derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
    delta_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
    extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
    intensional_predicates: &IndexSet<PredicateId>,
    current_scc_predicates: &IndexSet<PredicateId>,
) -> Result<Vec<MatchState>, RuntimeError> {
    // Start from a single empty match and extend it literal by literal.
    let mut states = vec![MatchState::default()];

    for (literal_index, literal) in ordered_rule_body(rule) {
        match literal {
            Literal::Positive(atom) => {
                let rows = positive_relation_rows(
                    atom.predicate.id,
                    // Only the anchored literal is restricted to the delta relation.
                    (delta_anchor_index == Some(literal_index)).then_some(()),
                    derived_rows,
                    delta_rows,
                    extensional_rows,
                    intensional_predicates,
                    current_scc_predicates,
                )?;
                let mut next_states = Vec::new();

                // Nested-loop join of the partial matches with the relation rows.
                for state in &states {
                    for row in &rows {
                        if let Some(bindings) =
                            unify_terms(&state.bindings, &atom.terms, &row.values)
                        {
                            // Only derived rows carry a tuple id to record as parent.
                            let mut parent_tuple_ids = state.parent_tuple_ids.clone();
                            if let Some(tuple_id) = row.tuple_id {
                                if !parent_tuple_ids.contains(&tuple_id) {
                                    parent_tuple_ids.push(tuple_id);
                                }
                            }
                            let mut source_datom_ids = state.source_datom_ids.clone();
                            extend_unique(&mut source_datom_ids, &row.source_datom_ids);
                            next_states.push(MatchState {
                                bindings,
                                parent_tuple_ids,
                                source_datom_ids,
                                imported_cuts: merge_partition_cuts(
                                    state.imported_cuts.iter().chain(row.imported_cuts.iter()),
                                ),
                                query_tuple_id: row.tuple_id.or(state.query_tuple_id),
                                policy: merge_policy_envelopes([
                                    state.policy.as_ref(),
                                    row.policy.as_ref(),
                                ]),
                            });
                        }
                    }
                }

                states = next_states;
            }
            Literal::Negative(atom) => {
                // Negation is only evaluated against relations of earlier SCCs.
                if current_scc_predicates.contains(&atom.predicate.id) {
                    return Err(RuntimeError::UnsupportedIntraStratumNegation(rule.id));
                }
                let rows = negative_relation_rows(
                    atom.predicate.id,
                    derived_rows,
                    extensional_rows,
                    intensional_predicates,
                )?;
                // Drop every partial match for which some row satisfies the negated atom.
                states.retain(|state| {
                    !rows
                        .iter()
                        .any(|row| unify_terms(&state.bindings, &atom.terms, &row.values).is_some())
                });
            }
        }

        // Nothing left to extend: the remaining literals cannot add matches.
        if states.is_empty() {
            break;
        }
    }

    Ok(states)
}
647
648fn ordered_rule_body(rule: &RuleAst) -> Vec<(usize, &Literal)> {
649 let mut positives = Vec::new();
650 let mut negatives = Vec::new();
651 for (index, literal) in rule.body.iter().enumerate() {
652 match literal {
653 Literal::Positive(_) => positives.push((index, literal)),
654 Literal::Negative(_) => negatives.push((index, literal)),
655 }
656 }
657 positives.extend(negatives);
658 positives
659}
660
661fn positive_relation_rows(
662 predicate: PredicateId,
663 use_delta: Option<()>,
664 derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
665 delta_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
666 extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
667 intensional_predicates: &IndexSet<PredicateId>,
668 _current_scc_predicates: &IndexSet<PredicateId>,
669) -> Result<Vec<RelationRow>, RuntimeError> {
670 if use_delta.is_some() {
671 return Ok(delta_rows.get(&predicate).cloned().unwrap_or_default());
672 }
673 if intensional_predicates.contains(&predicate) {
674 return Ok(derived_rows.get(&predicate).cloned().unwrap_or_default());
675 }
676
677 extensional_rows
678 .get(&predicate)
679 .cloned()
680 .ok_or(RuntimeError::MissingExtensionalBinding(predicate))
681}
682
683fn negative_relation_rows(
684 predicate: PredicateId,
685 derived_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
686 extensional_rows: &IndexMap<PredicateId, Vec<RelationRow>>,
687 intensional_predicates: &IndexSet<PredicateId>,
688) -> Result<Vec<RelationRow>, RuntimeError> {
689 if intensional_predicates.contains(&predicate) {
690 Ok(derived_rows.get(&predicate).cloned().unwrap_or_default())
691 } else {
692 extensional_rows
693 .get(&predicate)
694 .cloned()
695 .ok_or(RuntimeError::MissingExtensionalBinding(predicate))
696 }
697}
698
699fn unify_terms(
700 bindings: &IndexMap<Variable, Value>,
701 terms: &[Term],
702 values: &[Value],
703) -> Option<IndexMap<Variable, Value>> {
704 if terms.len() != values.len() {
705 return None;
706 }
707
708 let mut next_bindings = bindings.clone();
709 for (term, value) in terms.iter().zip(values) {
710 match term {
711 Term::Variable(variable) => match next_bindings.get(variable) {
712 Some(bound) if bound != value => return None,
713 Some(_) => {}
714 None => {
715 next_bindings.insert(variable.clone(), value.clone());
716 }
717 },
718 Term::Value(expected) if expected != value => return None,
719 Term::Value(_) => {}
720 Term::Aggregate(_) => return None,
721 }
722 }
723
724 Some(next_bindings)
725}
726
727fn materialize_non_aggregate_head(
728 rule_id: RuleId,
729 terms: &[Term],
730 bindings: &IndexMap<Variable, Value>,
731) -> Result<Vec<Value>, RuntimeError> {
732 terms
733 .iter()
734 .map(|term| match term {
735 Term::Variable(variable) => {
736 bindings
737 .get(variable)
738 .cloned()
739 .ok_or_else(|| RuntimeError::UnboundVariable {
740 rule_id,
741 variable: variable.0.clone(),
742 })
743 }
744 Term::Value(value) => Ok(value.clone()),
745 Term::Aggregate(_) => Err(RuntimeError::UnexpectedAggregate(rule_id)),
746 })
747 .collect()
748}
749
/// Groups body matches by the non-aggregate head values and folds the aggregate terms
/// of the head over each group.
///
/// `aggregates` lists the head positions that hold aggregate terms. Within a group,
/// matches whose complete variable bindings are identical are folded only once. The
/// provenance, partition cuts, and policy of every folded match are merged into the
/// group. The result is sorted by the rendered key of the final values.
///
/// # Errors
///
/// * [`RuntimeError::UnboundVariable`] when a head or aggregate variable has no binding.
/// * Any error raised by the accumulators for unsupported or mixed input types.
fn materialize_aggregate_head(
    rule_id: RuleId,
    terms: &[Term],
    aggregates: &[(usize, &AggregateTerm)],
    matches: &[MatchState],
) -> Result<Vec<AggregatedMatch>, RuntimeError> {
    let mut groups: IndexMap<String, AggregateGroup> = IndexMap::new();

    for matched in matches {
        // Identity of the whole binding set, used to fold each distinct binding once.
        let binding_key = bindings_key(&matched.bindings);
        let group_values = materialize_group_values(rule_id, terms, aggregates, &matched.bindings)?;
        let group_key = values_key(&group_values);

        if !groups.contains_key(&group_key) {
            // The first match of a group seeds the accumulators with its own values.
            let accumulators = aggregates
                .iter()
                .map(|(_, aggregate_term)| {
                    let aggregate_value = matched
                        .bindings
                        .get(&aggregate_term.variable)
                        .ok_or_else(|| RuntimeError::UnboundVariable {
                            rule_id,
                            variable: aggregate_term.variable.0.clone(),
                        })?;
                    AggregateAccumulator::from_value(
                        rule_id,
                        aggregate_term.function,
                        aggregate_value,
                    )
                })
                .collect::<Result<Vec<_>, _>>()?;
            groups.insert(
                group_key.clone(),
                AggregateGroup {
                    values: group_values.into_iter().map(Some).collect(),
                    accumulators,
                    seen_bindings: IndexSet::new(),
                    parent_tuple_ids: Vec::new(),
                    source_datom_ids: Vec::new(),
                    imported_cuts: Vec::new(),
                    policy: None,
                },
            );
        }
        let group = groups
            .get_mut(&group_key)
            .expect("aggregate group should exist after insertion");

        // A binding set that was already folded contributes nothing further.
        if !group.seen_bindings.insert(binding_key) {
            continue;
        }

        // The seeding match is already reflected in the accumulators, so only the
        // second and later distinct bindings are added.
        if group.seen_bindings.len() > 1 {
            for (accumulator, (_, aggregate_term)) in
                group.accumulators.iter_mut().zip(aggregates.iter())
            {
                let aggregate_value =
                    matched
                        .bindings
                        .get(&aggregate_term.variable)
                        .ok_or_else(|| RuntimeError::UnboundVariable {
                            rule_id,
                            variable: aggregate_term.variable.0.clone(),
                        })?;
                accumulator.add(rule_id, aggregate_term.function, aggregate_value)?;
            }
        }
        extend_unique(&mut group.parent_tuple_ids, &matched.parent_tuple_ids);
        extend_unique(&mut group.source_datom_ids, &matched.source_datom_ids);
        group.imported_cuts = merge_partition_cuts(
            group
                .imported_cuts
                .iter()
                .chain(matched.imported_cuts.iter()),
        );
        group.policy = merge_policy_envelopes([group.policy.as_ref(), matched.policy.as_ref()]);
    }

    let mut aggregated = groups
        .into_values()
        .map(|mut group| {
            let mut values = group
                .values
                .into_iter()
                .map(|value| value.expect("group values are initialized"))
                .collect::<Vec<_>>();
            // Overwrite the placeholder at each aggregate position with the final result.
            for ((aggregate_index, _), accumulator) in
                aggregates.iter().zip(group.accumulators.drain(..))
            {
                values[*aggregate_index] = accumulator.finalize();
            }
            AggregatedMatch {
                values,
                parent_tuple_ids: group.parent_tuple_ids,
                source_datom_ids: group.source_datom_ids,
                imported_cuts: group.imported_cuts,
                policy: group.policy,
            }
        })
        .collect::<Vec<_>>();

    // Sort by rendered values so the output does not depend on match order.
    aggregated.sort_by_key(|group| values_key(&group.values));
    Ok(aggregated)
}
854
855fn materialize_group_values(
856 rule_id: RuleId,
857 terms: &[Term],
858 aggregates: &[(usize, &AggregateTerm)],
859 bindings: &IndexMap<Variable, Value>,
860) -> Result<Vec<Value>, RuntimeError> {
861 terms
862 .iter()
863 .enumerate()
864 .map(|(index, term)| {
865 if aggregates
866 .iter()
867 .any(|(aggregate_index, _)| index == *aggregate_index)
868 {
869 return Ok(Value::Null);
870 }
871 match term {
872 Term::Variable(variable) => {
873 bindings
874 .get(variable)
875 .cloned()
876 .ok_or_else(|| RuntimeError::UnboundVariable {
877 rule_id,
878 variable: variable.0.clone(),
879 })
880 }
881 Term::Value(value) => Ok(value.clone()),
882 Term::Aggregate(_) => Err(RuntimeError::UnexpectedAggregate(rule_id)),
883 }
884 })
885 .collect()
886}
887
888fn head_aggregates(rule: &RuleAst) -> Vec<(usize, &AggregateTerm)> {
889 rule.head
890 .terms
891 .iter()
892 .enumerate()
893 .filter_map(|(index, term)| match term {
894 Term::Aggregate(aggregate) => Some((index, aggregate)),
895 _ => None,
896 })
897 .collect()
898}
899
900fn bindings_key(bindings: &IndexMap<Variable, Value>) -> String {
901 let mut entries = bindings
902 .iter()
903 .map(|(variable, value)| (variable.0.as_str(), value_key(value)))
904 .collect::<Vec<_>>();
905 entries.sort_unstable_by(|left, right| left.0.cmp(right.0));
906
907 let mut rendered = String::new();
908 for (variable, value) in entries {
909 rendered.push_str(variable);
910 rendered.push('=');
911 rendered.push_str(&value);
912 rendered.push('|');
913 }
914 rendered
915}
916
917fn values_key(values: &[Value]) -> String {
918 let mut rendered = String::new();
919 for value in values {
920 rendered.push_str(&value_key(value));
921 rendered.push('|');
922 }
923 rendered
924}
925
926fn build_scc_lookup(program: &CompiledProgram) -> IndexMap<PredicateId, usize> {
927 let mut lookup = IndexMap::new();
928 for scc in &program.sccs {
929 for predicate in &scc.predicates {
930 lookup.insert(*predicate, scc.id);
931 }
932 }
933 lookup
934}
935
936impl AggregateAccumulator {
937 fn from_value(
938 rule_id: RuleId,
939 function: AggregateFunction,
940 value: &Value,
941 ) -> Result<Self, RuntimeError> {
942 match function {
943 AggregateFunction::Count => Ok(Self::Count(1)),
944 AggregateFunction::Sum => match value {
945 Value::I64(inner) => Ok(Self::SumI64(*inner)),
946 Value::U64(inner) => Ok(Self::SumU64(*inner)),
947 Value::F64(inner) => Ok(Self::SumF64(*inner)),
948 other => Err(RuntimeError::UnsupportedAggregateInput {
949 rule_id,
950 function,
951 actual: runtime_value_type(other),
952 }),
953 },
954 AggregateFunction::Min => {
955 validate_orderable_input(rule_id, function, value).map(|_| Self::Min(value.clone()))
956 }
957 AggregateFunction::Max => {
958 validate_orderable_input(rule_id, function, value).map(|_| Self::Max(value.clone()))
959 }
960 }
961 }
962
963 fn add(
964 &mut self,
965 rule_id: RuleId,
966 function: AggregateFunction,
967 value: &Value,
968 ) -> Result<(), RuntimeError> {
969 match self {
970 Self::Count(count) => {
971 *count += 1;
972 Ok(())
973 }
974 Self::SumI64(total) => match value {
975 Value::I64(inner) => {
976 *total += inner;
977 Ok(())
978 }
979 other => Err(RuntimeError::AggregateInputTypeMismatch {
980 rule_id,
981 function,
982 expected: "I64".into(),
983 actual: runtime_value_type(other),
984 }),
985 },
986 Self::SumU64(total) => match value {
987 Value::U64(inner) => {
988 *total += inner;
989 Ok(())
990 }
991 other => Err(RuntimeError::AggregateInputTypeMismatch {
992 rule_id,
993 function,
994 expected: "U64".into(),
995 actual: runtime_value_type(other),
996 }),
997 },
998 Self::SumF64(total) => match value {
999 Value::F64(inner) => {
1000 *total += inner;
1001 Ok(())
1002 }
1003 other => Err(RuntimeError::AggregateInputTypeMismatch {
1004 rule_id,
1005 function,
1006 expected: "F64".into(),
1007 actual: runtime_value_type(other),
1008 }),
1009 },
1010 Self::Min(current) => {
1011 validate_orderable_input(rule_id, function, value)?;
1012 if compare_values(current, value)? == Ordering::Greater {
1013 *current = value.clone();
1014 }
1015 Ok(())
1016 }
1017 Self::Max(current) => {
1018 validate_orderable_input(rule_id, function, value)?;
1019 if compare_values(current, value)? == Ordering::Less {
1020 *current = value.clone();
1021 }
1022 Ok(())
1023 }
1024 }
1025 }
1026
1027 fn finalize(self) -> Value {
1028 match self {
1029 Self::Count(inner) => Value::U64(inner),
1030 Self::SumI64(inner) => Value::I64(inner),
1031 Self::SumU64(inner) => Value::U64(inner),
1032 Self::SumF64(inner) => Value::F64(inner),
1033 Self::Min(inner) | Self::Max(inner) => inner,
1034 }
1035 }
1036}
1037
1038fn validate_orderable_input(
1039 rule_id: RuleId,
1040 function: AggregateFunction,
1041 value: &Value,
1042) -> Result<(), RuntimeError> {
1043 match value {
1044 Value::I64(_) | Value::U64(_) | Value::F64(_) | Value::String(_) | Value::Entity(_) => {
1045 Ok(())
1046 }
1047 other => Err(RuntimeError::UnsupportedAggregateInput {
1048 rule_id,
1049 function,
1050 actual: runtime_value_type(other),
1051 }),
1052 }
1053}
1054
1055fn compare_values(left: &Value, right: &Value) -> Result<Ordering, RuntimeError> {
1056 match (left, right) {
1057 (Value::I64(left_inner), Value::I64(right_inner)) => Ok(left_inner.cmp(right_inner)),
1058 (Value::U64(left_inner), Value::U64(right_inner)) => Ok(left_inner.cmp(right_inner)),
1059 (Value::F64(left_inner), Value::F64(right_inner)) => left_inner
1060 .partial_cmp(right_inner)
1061 .ok_or_else(|| RuntimeError::NonComparableAggregateValues {
1062 left: runtime_value_type(left),
1063 right: runtime_value_type(right),
1064 }),
1065 (Value::String(left_inner), Value::String(right_inner)) => Ok(left_inner.cmp(right_inner)),
1066 (Value::Entity(left_inner), Value::Entity(right_inner)) => Ok(left_inner.cmp(right_inner)),
1067 _ => Err(RuntimeError::NonComparableAggregateValues {
1068 left: runtime_value_type(left),
1069 right: runtime_value_type(right),
1070 }),
1071 }
1072}
1073
1074fn tuple_key(predicate: PredicateId, values: &[Value]) -> String {
1075 let mut key = format!("{}#", predicate.0);
1076 for value in values {
1077 key.push_str(&value_key(value));
1078 key.push('|');
1079 }
1080 key
1081}
1082
/// Appends to `target`, in order, every element of `additions` that `target` does not
/// already contain. Duplicates inside `additions` are appended only once.
fn extend_unique<T>(target: &mut Vec<T>, additions: &[T])
where
    T: Copy + Eq,
{
    for &candidate in additions {
        let already_present = target.iter().any(|existing| *existing == candidate);
        if !already_present {
            target.push(candidate);
        }
    }
}
1093
1094fn value_key(value: &Value) -> String {
1095 match value {
1096 Value::Null => "null".into(),
1097 Value::Bool(inner) => format!("bool:{inner}"),
1098 Value::I64(inner) => format!("i64:{inner}"),
1099 Value::U64(inner) => format!("u64:{inner}"),
1100 Value::F64(inner) => format!("f64:{:016x}", inner.to_bits()),
1101 Value::String(inner) => format!("string:{}:{inner}", inner.len()),
1102 Value::Bytes(inner) => format!("bytes:{inner:?}"),
1103 Value::Entity(inner) => format!("entity:{}", inner.0),
1104 Value::List(inner) => {
1105 let mut rendered = String::from("list:[");
1106 for value in inner {
1107 rendered.push_str(&value_key(value));
1108 rendered.push(',');
1109 }
1110 rendered.push(']');
1111 rendered
1112 }
1113 }
1114}
1115
1116fn runtime_value_type(value: &Value) -> String {
1117 match value {
1118 Value::Null => "Null".into(),
1119 Value::Bool(_) => "Bool".into(),
1120 Value::I64(_) => "I64".into(),
1121 Value::U64(_) => "U64".into(),
1122 Value::F64(_) => "F64".into(),
1123 Value::String(_) => "String".into(),
1124 Value::Bytes(_) => "Bytes".into(),
1125 Value::Entity(_) => "Entity".into(),
1126 Value::List(_) => "List".into(),
1127 }
1128}
1129
/// Errors raised while evaluating rules or answering queries.
#[derive(Debug, Error)]
pub enum RuntimeError {
    /// A literal or goal names a predicate that is not a rule head and has no
    /// extensional rows.
    #[error("predicate {0} has no extensional binding or fact rows in the compiled program")]
    MissingExtensionalBinding(PredicateId),
    /// A rule negates a predicate that belongs to the SCC being evaluated.
    #[error("rule {0} uses same-stratum negation, which is not supported")]
    UnsupportedIntraStratumNegation(RuleId),
    /// A head or aggregate variable has no binding in a body match.
    #[error("rule {rule_id} references unbound variable {variable}")]
    UnboundVariable { rule_id: RuleId, variable: String },
    /// An aggregate term was found where a plain head term was expected.
    #[error(
        "rule {0} requires grouped aggregate materialization, but was evaluated as a plain rule"
    )]
    UnexpectedAggregate(RuleId),
    /// An aggregate function received a value of a kind it cannot process.
    #[error(
        "rule {rule_id} uses aggregate {function} over unsupported runtime value type {actual}"
    )]
    UnsupportedAggregateInput {
        rule_id: RuleId,
        function: AggregateFunction,
        actual: String,
    },
    /// A sum received a value whose numeric type differs from its first input.
    #[error(
        "rule {rule_id} uses aggregate {function} with mixed runtime input types: expected {expected}, found {actual}"
    )]
    AggregateInputTypeMismatch {
        rule_id: RuleId,
        function: AggregateFunction,
        expected: String,
        actual: String,
    },
    /// `Min`/`Max` had to compare two values that have no ordering between them.
    #[error("aggregate comparison requires comparable values, found {left} and {right}")]
    NonComparableAggregateValues { left: String, right: String },
}
1162
1163#[cfg(test)]
1164mod tests {
1165 use super::{execute_query, RuleRuntime, RuntimeError, SemiNaiveRuntime};
1166 use aether_ast::{
1167 AggregateFunction, AggregateTerm, Atom, AttributeId, Datom, DatomProvenance, ElementId,
1168 EntityId, ExtensionalFact, Literal, PredicateId, PredicateRef, QueryAst, QueryRow, RuleAst,
1169 RuleId, RuleProgram, Term, Value, Variable,
1170 };
1171 use aether_resolver::{MaterializedResolver, Resolver};
1172 use aether_rules::{DefaultRuleCompiler, RuleCompiler};
1173 use aether_schema::{AttributeClass, AttributeSchema, PredicateSignature, Schema, ValueType};
1174
1175 fn predicate(id: u64, name: &str, arity: usize) -> PredicateRef {
1176 PredicateRef {
1177 id: PredicateId::new(id),
1178 name: name.into(),
1179 arity,
1180 }
1181 }
1182
1183 fn atom(predicate: PredicateRef, vars: &[&str]) -> Atom {
1184 Atom {
1185 predicate,
1186 terms: vars
1187 .iter()
1188 .map(|name| Term::Variable(Variable::new(*name)))
1189 .collect(),
1190 }
1191 }
1192
1193 fn aggregate(function: AggregateFunction, variable: &str) -> Term {
1194 Term::Aggregate(AggregateTerm {
1195 function,
1196 variable: Variable::new(variable),
1197 })
1198 }
1199
1200 fn dependency_datom(entity: u64, value: u64, element: u64) -> Datom {
1201 Datom {
1202 entity: EntityId::new(entity),
1203 attribute: AttributeId::new(1),
1204 value: Value::Entity(EntityId::new(value)),
1205 op: aether_ast::OperationKind::Add,
1206 element: ElementId::new(element),
1207 replica: aether_ast::ReplicaId::new(1),
1208 causal_context: Default::default(),
1209 provenance: DatomProvenance::default(),
1210 policy: None,
1211 }
1212 }
1213
1214 #[test]
1215 fn monotone_transitive_closure_converges_with_iteration_metadata() {
1216 let mut schema = Schema::new("v1");
1217 schema
1218 .register_attribute(AttributeSchema {
1219 id: AttributeId::new(1),
1220 name: "task.depends_on".into(),
1221 class: AttributeClass::RefSet,
1222 value_type: ValueType::Entity,
1223 })
1224 .expect("register attribute");
1225 schema
1226 .register_predicate(PredicateSignature {
1227 id: PredicateId::new(1),
1228 name: "task_depends_on".into(),
1229 fields: vec![ValueType::Entity, ValueType::Entity],
1230 })
1231 .expect("register extensional predicate");
1232 schema
1233 .register_predicate(PredicateSignature {
1234 id: PredicateId::new(2),
1235 name: "depends_transitive".into(),
1236 fields: vec![ValueType::Entity, ValueType::Entity],
1237 })
1238 .expect("register recursive predicate");
1239
1240 let program = RuleProgram {
1241 predicates: vec![
1242 predicate(1, "task_depends_on", 2),
1243 predicate(2, "depends_transitive", 2),
1244 ],
1245 rules: vec![
1246 RuleAst {
1247 id: RuleId::new(1),
1248 head: atom(predicate(2, "depends_transitive", 2), &["x", "y"]),
1249 body: vec![Literal::Positive(atom(
1250 predicate(1, "task_depends_on", 2),
1251 &["x", "y"],
1252 ))],
1253 },
1254 RuleAst {
1255 id: RuleId::new(2),
1256 head: atom(predicate(2, "depends_transitive", 2), &["x", "z"]),
1257 body: vec![
1258 Literal::Positive(atom(predicate(2, "depends_transitive", 2), &["x", "y"])),
1259 Literal::Positive(atom(predicate(1, "task_depends_on", 2), &["y", "z"])),
1260 ],
1261 },
1262 ],
1263 materialized: vec![PredicateId::new(2)],
1264 facts: Vec::new(),
1265 };
1266 let datoms = vec![
1267 dependency_datom(1, 2, 1),
1268 dependency_datom(2, 3, 2),
1269 dependency_datom(3, 4, 3),
1270 ];
1271 let state = MaterializedResolver
1272 .current(&schema, &datoms)
1273 .expect("resolve current state");
1274 let compiled = DefaultRuleCompiler
1275 .compile(&schema, &program)
1276 .expect("compile recursive program");
1277
1278 let derived = SemiNaiveRuntime
1279 .evaluate(&state, &compiled)
1280 .expect("evaluate recursive closure");
1281
1282 let mut pairs = derived
1283 .tuples
1284 .iter()
1285 .map(|tuple| {
1286 let [Value::Entity(left), Value::Entity(right)] = &tuple.tuple.values[..] else {
1287 panic!("expected binary entity tuple");
1288 };
1289 (left.0, right.0)
1290 })
1291 .collect::<Vec<_>>();
1292 pairs.sort_unstable();
1293
1294 assert_eq!(pairs, vec![(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]);
1295 assert_eq!(
1296 derived
1297 .iterations
1298 .iter()
1299 .map(|iteration| iteration.delta_size)
1300 .collect::<Vec<_>>(),
1301 vec![3, 2, 1, 0]
1302 );
1303 let longest_path = derived
1304 .tuples
1305 .iter()
1306 .find(|tuple| {
1307 tuple.tuple.values
1308 == vec![
1309 Value::Entity(EntityId::new(1)),
1310 Value::Entity(EntityId::new(4)),
1311 ]
1312 })
1313 .expect("longest-path tuple");
1314 assert_eq!(
1315 longest_path.metadata.source_datom_ids,
1316 vec![ElementId::new(1), ElementId::new(2), ElementId::new(3)]
1317 );
1318 }
1319
1320 #[test]
1321 fn bounded_aggregation_materializes_counts_sums_and_maxima() {
1322 let mut schema = Schema::new("v1");
1323 for signature in [
1324 PredicateSignature {
1325 id: PredicateId::new(1),
1326 name: "edge".into(),
1327 fields: vec![ValueType::Entity, ValueType::Entity],
1328 },
1329 PredicateSignature {
1330 id: PredicateId::new(2),
1331 name: "reach".into(),
1332 fields: vec![ValueType::Entity, ValueType::Entity],
1333 },
1334 PredicateSignature {
1335 id: PredicateId::new(3),
1336 name: "reachable_count".into(),
1337 fields: vec![ValueType::Entity, ValueType::U64],
1338 },
1339 PredicateSignature {
1340 id: PredicateId::new(4),
1341 name: "project_task".into(),
1342 fields: vec![ValueType::Entity, ValueType::Entity],
1343 },
1344 PredicateSignature {
1345 id: PredicateId::new(5),
1346 name: "task_hours".into(),
1347 fields: vec![ValueType::Entity, ValueType::U64],
1348 },
1349 PredicateSignature {
1350 id: PredicateId::new(6),
1351 name: "project_hours".into(),
1352 fields: vec![ValueType::Entity, ValueType::U64],
1353 },
1354 PredicateSignature {
1355 id: PredicateId::new(9),
1356 name: "project_stats".into(),
1357 fields: vec![ValueType::Entity, ValueType::U64, ValueType::U64],
1358 },
1359 PredicateSignature {
1360 id: PredicateId::new(7),
1361 name: "execution_attempt".into(),
1362 fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1363 },
1364 PredicateSignature {
1365 id: PredicateId::new(8),
1366 name: "latest_epoch".into(),
1367 fields: vec![ValueType::Entity, ValueType::U64],
1368 },
1369 ] {
1370 schema
1371 .register_predicate(signature)
1372 .expect("register predicate");
1373 }
1374
1375 let program = RuleProgram {
1376 predicates: vec![
1377 predicate(1, "edge", 2),
1378 predicate(2, "reach", 2),
1379 predicate(3, "reachable_count", 2),
1380 predicate(4, "project_task", 2),
1381 predicate(5, "task_hours", 2),
1382 predicate(6, "project_hours", 2),
1383 predicate(9, "project_stats", 3),
1384 predicate(7, "execution_attempt", 3),
1385 predicate(8, "latest_epoch", 2),
1386 ],
1387 rules: vec![
1388 RuleAst {
1389 id: RuleId::new(1),
1390 head: atom(predicate(2, "reach", 2), &["x", "y"]),
1391 body: vec![Literal::Positive(atom(
1392 predicate(1, "edge", 2),
1393 &["x", "y"],
1394 ))],
1395 },
1396 RuleAst {
1397 id: RuleId::new(2),
1398 head: atom(predicate(2, "reach", 2), &["x", "z"]),
1399 body: vec![
1400 Literal::Positive(atom(predicate(2, "reach", 2), &["x", "y"])),
1401 Literal::Positive(atom(predicate(1, "edge", 2), &["y", "z"])),
1402 ],
1403 },
1404 RuleAst {
1405 id: RuleId::new(3),
1406 head: Atom {
1407 predicate: predicate(3, "reachable_count", 2),
1408 terms: vec![
1409 Term::Variable(Variable::new("x")),
1410 aggregate(AggregateFunction::Count, "y"),
1411 ],
1412 },
1413 body: vec![Literal::Positive(atom(
1414 predicate(2, "reach", 2),
1415 &["x", "y"],
1416 ))],
1417 },
1418 RuleAst {
1419 id: RuleId::new(4),
1420 head: Atom {
1421 predicate: predicate(6, "project_hours", 2),
1422 terms: vec![
1423 Term::Variable(Variable::new("project")),
1424 aggregate(AggregateFunction::Sum, "hours"),
1425 ],
1426 },
1427 body: vec![
1428 Literal::Positive(atom(
1429 predicate(4, "project_task", 2),
1430 &["project", "task"],
1431 )),
1432 Literal::Positive(atom(predicate(5, "task_hours", 2), &["task", "hours"])),
1433 ],
1434 },
1435 RuleAst {
1436 id: RuleId::new(5),
1437 head: Atom {
1438 predicate: predicate(9, "project_stats", 3),
1439 terms: vec![
1440 Term::Variable(Variable::new("project")),
1441 aggregate(AggregateFunction::Count, "task"),
1442 aggregate(AggregateFunction::Sum, "hours"),
1443 ],
1444 },
1445 body: vec![
1446 Literal::Positive(atom(
1447 predicate(4, "project_task", 2),
1448 &["project", "task"],
1449 )),
1450 Literal::Positive(atom(predicate(5, "task_hours", 2), &["task", "hours"])),
1451 ],
1452 },
1453 RuleAst {
1454 id: RuleId::new(6),
1455 head: Atom {
1456 predicate: predicate(8, "latest_epoch", 2),
1457 terms: vec![
1458 Term::Variable(Variable::new("task")),
1459 aggregate(AggregateFunction::Max, "epoch"),
1460 ],
1461 },
1462 body: vec![Literal::Positive(atom(
1463 predicate(7, "execution_attempt", 3),
1464 &["task", "worker", "epoch"],
1465 ))],
1466 },
1467 ],
1468 materialized: vec![
1469 PredicateId::new(2),
1470 PredicateId::new(3),
1471 PredicateId::new(6),
1472 PredicateId::new(9),
1473 PredicateId::new(8),
1474 ],
1475 facts: vec![
1476 ExtensionalFact {
1477 predicate: predicate(1, "edge", 2),
1478 values: vec![
1479 Value::Entity(EntityId::new(1)),
1480 Value::Entity(EntityId::new(2)),
1481 ],
1482 policy: None,
1483 provenance: None,
1484 },
1485 ExtensionalFact {
1486 predicate: predicate(1, "edge", 2),
1487 values: vec![
1488 Value::Entity(EntityId::new(2)),
1489 Value::Entity(EntityId::new(3)),
1490 ],
1491 policy: None,
1492 provenance: None,
1493 },
1494 ExtensionalFact {
1495 predicate: predicate(1, "edge", 2),
1496 values: vec![
1497 Value::Entity(EntityId::new(3)),
1498 Value::Entity(EntityId::new(4)),
1499 ],
1500 policy: None,
1501 provenance: None,
1502 },
1503 ExtensionalFact {
1504 predicate: predicate(4, "project_task", 2),
1505 values: vec![
1506 Value::Entity(EntityId::new(10)),
1507 Value::Entity(EntityId::new(101)),
1508 ],
1509 policy: None,
1510 provenance: None,
1511 },
1512 ExtensionalFact {
1513 predicate: predicate(4, "project_task", 2),
1514 values: vec![
1515 Value::Entity(EntityId::new(10)),
1516 Value::Entity(EntityId::new(102)),
1517 ],
1518 policy: None,
1519 provenance: None,
1520 },
1521 ExtensionalFact {
1522 predicate: predicate(5, "task_hours", 2),
1523 values: vec![Value::Entity(EntityId::new(101)), Value::U64(3)],
1524 policy: None,
1525 provenance: None,
1526 },
1527 ExtensionalFact {
1528 predicate: predicate(5, "task_hours", 2),
1529 values: vec![Value::Entity(EntityId::new(102)), Value::U64(5)],
1530 policy: None,
1531 provenance: None,
1532 },
1533 ExtensionalFact {
1534 predicate: predicate(7, "execution_attempt", 3),
1535 values: vec![
1536 Value::Entity(EntityId::new(1)),
1537 Value::String("worker-a".into()),
1538 Value::U64(1),
1539 ],
1540 policy: None,
1541 provenance: None,
1542 },
1543 ExtensionalFact {
1544 predicate: predicate(7, "execution_attempt", 3),
1545 values: vec![
1546 Value::Entity(EntityId::new(1)),
1547 Value::String("worker-b".into()),
1548 Value::U64(4),
1549 ],
1550 policy: None,
1551 provenance: None,
1552 },
1553 ],
1554 };
1555
1556 let compiled = DefaultRuleCompiler
1557 .compile(&schema, &program)
1558 .expect("compile aggregate program");
1559 let derived = SemiNaiveRuntime
1560 .evaluate(&Default::default(), &compiled)
1561 .expect("evaluate aggregate program");
1562
1563 let reachable_count = execute_query(
1564 &Default::default(),
1565 &compiled,
1566 &derived,
1567 &QueryAst {
1568 goals: vec![atom(predicate(3, "reachable_count", 2), &["x", "count"])],
1569 keep: vec![Variable::new("x"), Variable::new("count")],
1570 },
1571 None,
1572 )
1573 .expect("query reachable count");
1574 assert_eq!(
1575 reachable_count.rows,
1576 vec![
1577 QueryRow {
1578 values: vec![Value::Entity(EntityId::new(1)), Value::U64(3)],
1579 tuple_id: reachable_count.rows[0].tuple_id,
1580 },
1581 QueryRow {
1582 values: vec![Value::Entity(EntityId::new(2)), Value::U64(2)],
1583 tuple_id: reachable_count.rows[1].tuple_id,
1584 },
1585 QueryRow {
1586 values: vec![Value::Entity(EntityId::new(3)), Value::U64(1)],
1587 tuple_id: reachable_count.rows[2].tuple_id,
1588 },
1589 ]
1590 );
1591
1592 let project_hours = execute_query(
1593 &Default::default(),
1594 &compiled,
1595 &derived,
1596 &QueryAst {
1597 goals: vec![atom(
1598 predicate(6, "project_hours", 2),
1599 &["project", "hours"],
1600 )],
1601 keep: vec![Variable::new("project"), Variable::new("hours")],
1602 },
1603 None,
1604 )
1605 .expect("query project hours");
1606 assert_eq!(
1607 project_hours.rows[0].values,
1608 vec![Value::Entity(EntityId::new(10)), Value::U64(8)]
1609 );
1610
1611 let project_stats = execute_query(
1612 &Default::default(),
1613 &compiled,
1614 &derived,
1615 &QueryAst {
1616 goals: vec![atom(
1617 predicate(9, "project_stats", 3),
1618 &["project", "task_count", "hours"],
1619 )],
1620 keep: vec![
1621 Variable::new("project"),
1622 Variable::new("task_count"),
1623 Variable::new("hours"),
1624 ],
1625 },
1626 None,
1627 )
1628 .expect("query project stats");
1629 assert_eq!(
1630 project_stats.rows[0].values,
1631 vec![
1632 Value::Entity(EntityId::new(10)),
1633 Value::U64(2),
1634 Value::U64(8),
1635 ]
1636 );
1637
1638 let latest_epoch = execute_query(
1639 &Default::default(),
1640 &compiled,
1641 &derived,
1642 &QueryAst {
1643 goals: vec![atom(predicate(8, "latest_epoch", 2), &["task", "epoch"])],
1644 keep: vec![Variable::new("task"), Variable::new("epoch")],
1645 },
1646 None,
1647 )
1648 .expect("query latest epoch");
1649 assert_eq!(
1650 latest_epoch.rows[0].values,
1651 vec![Value::Entity(EntityId::new(1)), Value::U64(4)]
1652 );
1653 }
1654
1655 #[test]
1656 fn stratified_negation_supports_readiness_and_stale_rejection() {
1657 let mut schema = Schema::new("v1");
1658 for attribute in [
1659 AttributeSchema {
1660 id: AttributeId::new(1),
1661 name: "task.depends_on".into(),
1662 class: AttributeClass::RefSet,
1663 value_type: ValueType::Entity,
1664 },
1665 AttributeSchema {
1666 id: AttributeId::new(2),
1667 name: "task.status".into(),
1668 class: AttributeClass::ScalarLww,
1669 value_type: ValueType::String,
1670 },
1671 AttributeSchema {
1672 id: AttributeId::new(3),
1673 name: "task.claimed_by".into(),
1674 class: AttributeClass::ScalarLww,
1675 value_type: ValueType::String,
1676 },
1677 AttributeSchema {
1678 id: AttributeId::new(4),
1679 name: "task.lease_epoch".into(),
1680 class: AttributeClass::ScalarLww,
1681 value_type: ValueType::U64,
1682 },
1683 AttributeSchema {
1684 id: AttributeId::new(5),
1685 name: "task.lease_state".into(),
1686 class: AttributeClass::ScalarLww,
1687 value_type: ValueType::String,
1688 },
1689 ] {
1690 schema
1691 .register_attribute(attribute)
1692 .expect("register attribute");
1693 }
1694
1695 for signature in [
1696 PredicateSignature {
1697 id: PredicateId::new(1),
1698 name: "task".into(),
1699 fields: vec![ValueType::Entity],
1700 },
1701 PredicateSignature {
1702 id: PredicateId::new(2),
1703 name: "execution_attempt".into(),
1704 fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1705 },
1706 PredicateSignature {
1707 id: PredicateId::new(3),
1708 name: "task_depends_on".into(),
1709 fields: vec![ValueType::Entity, ValueType::Entity],
1710 },
1711 PredicateSignature {
1712 id: PredicateId::new(4),
1713 name: "task_status".into(),
1714 fields: vec![ValueType::Entity, ValueType::String],
1715 },
1716 PredicateSignature {
1717 id: PredicateId::new(5),
1718 name: "task_claimed_by".into(),
1719 fields: vec![ValueType::Entity, ValueType::String],
1720 },
1721 PredicateSignature {
1722 id: PredicateId::new(6),
1723 name: "task_lease_epoch".into(),
1724 fields: vec![ValueType::Entity, ValueType::U64],
1725 },
1726 PredicateSignature {
1727 id: PredicateId::new(7),
1728 name: "task_lease_state".into(),
1729 fields: vec![ValueType::Entity, ValueType::String],
1730 },
1731 PredicateSignature {
1732 id: PredicateId::new(8),
1733 name: "task_complete".into(),
1734 fields: vec![ValueType::Entity],
1735 },
1736 PredicateSignature {
1737 id: PredicateId::new(9),
1738 name: "dependency_blocked".into(),
1739 fields: vec![ValueType::Entity],
1740 },
1741 PredicateSignature {
1742 id: PredicateId::new(10),
1743 name: "lease_active".into(),
1744 fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1745 },
1746 PredicateSignature {
1747 id: PredicateId::new(11),
1748 name: "active_claim".into(),
1749 fields: vec![ValueType::Entity],
1750 },
1751 PredicateSignature {
1752 id: PredicateId::new(12),
1753 name: "task_ready".into(),
1754 fields: vec![ValueType::Entity],
1755 },
1756 PredicateSignature {
1757 id: PredicateId::new(13),
1758 name: "execution_rejected_stale".into(),
1759 fields: vec![ValueType::Entity, ValueType::String, ValueType::U64],
1760 },
1761 ] {
1762 schema
1763 .register_predicate(signature)
1764 .expect("register predicate");
1765 }
1766
1767 let program = RuleProgram {
1768 predicates: vec![
1769 predicate(1, "task", 1),
1770 predicate(2, "execution_attempt", 3),
1771 predicate(3, "task_depends_on", 2),
1772 predicate(4, "task_status", 2),
1773 predicate(5, "task_claimed_by", 2),
1774 predicate(6, "task_lease_epoch", 2),
1775 predicate(7, "task_lease_state", 2),
1776 predicate(8, "task_complete", 1),
1777 predicate(9, "dependency_blocked", 1),
1778 predicate(10, "lease_active", 3),
1779 predicate(11, "active_claim", 1),
1780 predicate(12, "task_ready", 1),
1781 predicate(13, "execution_rejected_stale", 3),
1782 ],
1783 rules: vec![
1784 RuleAst {
1785 id: RuleId::new(1),
1786 head: atom(predicate(8, "task_complete", 1), &["t"]),
1787 body: vec![Literal::Positive(Atom {
1788 predicate: predicate(4, "task_status", 2),
1789 terms: vec![
1790 Term::Variable(Variable::new("t")),
1791 Term::Value(Value::String("done".into())),
1792 ],
1793 })],
1794 },
1795 RuleAst {
1796 id: RuleId::new(2),
1797 head: atom(predicate(9, "dependency_blocked", 1), &["t"]),
1798 body: vec![
1799 Literal::Positive(atom(predicate(3, "task_depends_on", 2), &["t", "dep"])),
1800 Literal::Negative(atom(predicate(8, "task_complete", 1), &["dep"])),
1801 ],
1802 },
1803 RuleAst {
1804 id: RuleId::new(3),
1805 head: atom(predicate(10, "lease_active", 3), &["t", "worker", "epoch"]),
1806 body: vec![
1807 Literal::Positive(atom(
1808 predicate(5, "task_claimed_by", 2),
1809 &["t", "worker"],
1810 )),
1811 Literal::Positive(atom(
1812 predicate(6, "task_lease_epoch", 2),
1813 &["t", "epoch"],
1814 )),
1815 Literal::Positive(Atom {
1816 predicate: predicate(7, "task_lease_state", 2),
1817 terms: vec![
1818 Term::Variable(Variable::new("t")),
1819 Term::Value(Value::String("active".into())),
1820 ],
1821 }),
1822 ],
1823 },
1824 RuleAst {
1825 id: RuleId::new(4),
1826 head: atom(predicate(11, "active_claim", 1), &["t"]),
1827 body: vec![Literal::Positive(atom(
1828 predicate(10, "lease_active", 3),
1829 &["t", "worker", "epoch"],
1830 ))],
1831 },
1832 RuleAst {
1833 id: RuleId::new(5),
1834 head: atom(predicate(12, "task_ready", 1), &["t"]),
1835 body: vec![
1836 Literal::Positive(atom(predicate(1, "task", 1), &["t"])),
1837 Literal::Negative(atom(predicate(9, "dependency_blocked", 1), &["t"])),
1838 Literal::Negative(atom(predicate(11, "active_claim", 1), &["t"])),
1839 ],
1840 },
1841 RuleAst {
1842 id: RuleId::new(6),
1843 head: atom(
1844 predicate(13, "execution_rejected_stale", 3),
1845 &["t", "worker", "epoch"],
1846 ),
1847 body: vec![
1848 Literal::Positive(atom(
1849 predicate(2, "execution_attempt", 3),
1850 &["t", "worker", "epoch"],
1851 )),
1852 Literal::Negative(atom(
1853 predicate(10, "lease_active", 3),
1854 &["t", "worker", "epoch"],
1855 )),
1856 ],
1857 },
1858 ],
1859 materialized: vec![PredicateId::new(12), PredicateId::new(13)],
1860 facts: vec![
1861 ExtensionalFact {
1862 predicate: predicate(1, "task", 1),
1863 values: vec![Value::Entity(EntityId::new(1))],
1864 policy: None,
1865 provenance: None,
1866 },
1867 ExtensionalFact {
1868 predicate: predicate(1, "task", 1),
1869 values: vec![Value::Entity(EntityId::new(2))],
1870 policy: None,
1871 provenance: None,
1872 },
1873 ExtensionalFact {
1874 predicate: predicate(2, "execution_attempt", 3),
1875 values: vec![
1876 Value::Entity(EntityId::new(1)),
1877 Value::String("worker-a".into()),
1878 Value::U64(1),
1879 ],
1880 policy: None,
1881 provenance: None,
1882 },
1883 ],
1884 };
1885 let datoms = vec![
1886 dependency_datom(1, 2, 1),
1887 datom(2, 2, Value::String("done".into()), 2),
1888 datom(1, 3, Value::String("worker-a".into()), 3),
1889 datom(1, 4, Value::U64(1), 4),
1890 datom(1, 5, Value::String("active".into()), 5),
1891 datom(1, 5, Value::String("expired".into()), 6),
1892 ];
1893
1894 let compiled = DefaultRuleCompiler
1895 .compile(&schema, &program)
1896 .expect("compile coordination program");
1897 let as_of_state = MaterializedResolver
1898 .as_of(&schema, &datoms, &ElementId::new(5))
1899 .expect("resolve as_of");
1900 let current_state = MaterializedResolver
1901 .current(&schema, &datoms)
1902 .expect("resolve current");
1903 let as_of_derived = SemiNaiveRuntime
1904 .evaluate(&as_of_state, &compiled)
1905 .expect("evaluate as_of");
1906 let current_derived = SemiNaiveRuntime
1907 .evaluate(¤t_state, &compiled)
1908 .expect("evaluate current");
1909
1910 let as_of_ready = execute_query(
1911 &as_of_state,
1912 &compiled,
1913 &as_of_derived,
1914 &QueryAst {
1915 goals: vec![
1916 atom(predicate(12, "task_ready", 1), &["t"]),
1917 Atom {
1918 predicate: predicate(5, "task_claimed_by", 2),
1919 terms: vec![
1920 Term::Variable(Variable::new("t")),
1921 Term::Value(Value::String("worker-a".into())),
1922 ],
1923 },
1924 ],
1925 keep: vec![Variable::new("t")],
1926 },
1927 None,
1928 )
1929 .expect("query as_of ready");
1930 assert!(as_of_ready.rows.is_empty());
1931
1932 let current_ready = execute_query(
1933 ¤t_state,
1934 &compiled,
1935 ¤t_derived,
1936 &QueryAst {
1937 goals: vec![
1938 atom(predicate(12, "task_ready", 1), &["t"]),
1939 Atom {
1940 predicate: predicate(5, "task_claimed_by", 2),
1941 terms: vec![
1942 Term::Variable(Variable::new("t")),
1943 Term::Value(Value::String("worker-a".into())),
1944 ],
1945 },
1946 ],
1947 keep: vec![Variable::new("t")],
1948 },
1949 None,
1950 )
1951 .expect("query current ready");
1952 assert_eq!(current_ready.rows.len(), 1);
1953 assert_eq!(
1954 current_ready.rows[0].values,
1955 vec![Value::Entity(EntityId::new(1))]
1956 );
1957
1958 let stale_attempts = execute_query(
1959 ¤t_state,
1960 &compiled,
1961 ¤t_derived,
1962 &QueryAst {
1963 goals: vec![atom(
1964 predicate(13, "execution_rejected_stale", 3),
1965 &["t", "worker", "epoch"],
1966 )],
1967 keep: vec![
1968 Variable::new("t"),
1969 Variable::new("worker"),
1970 Variable::new("epoch"),
1971 ],
1972 },
1973 None,
1974 )
1975 .expect("query stale attempts");
1976 assert_eq!(
1977 stale_attempts.rows,
1978 vec![QueryRow {
1979 values: vec![
1980 Value::Entity(EntityId::new(1)),
1981 Value::String("worker-a".into()),
1982 Value::U64(1),
1983 ],
1984 tuple_id: stale_attempts.rows.first().and_then(|row| row.tuple_id),
1985 }]
1986 );
1987 }
1988
1989 #[test]
1990 fn missing_extensional_binding_is_reported() {
1991 let mut schema = Schema::new("v1");
1992 schema
1993 .register_predicate(PredicateSignature {
1994 id: PredicateId::new(10),
1995 name: "edge".into(),
1996 fields: vec![ValueType::Entity, ValueType::Entity],
1997 })
1998 .expect("register edge");
1999 schema
2000 .register_predicate(PredicateSignature {
2001 id: PredicateId::new(11),
2002 name: "reach".into(),
2003 fields: vec![ValueType::Entity, ValueType::Entity],
2004 })
2005 .expect("register reach");
2006 let program = RuleProgram {
2007 predicates: vec![predicate(10, "edge", 2), predicate(11, "reach", 2)],
2008 rules: vec![RuleAst {
2009 id: RuleId::new(1),
2010 head: atom(predicate(11, "reach", 2), &["x", "y"]),
2011 body: vec![Literal::Positive(atom(
2012 predicate(10, "edge", 2),
2013 &["x", "y"],
2014 ))],
2015 }],
2016 materialized: vec![PredicateId::new(11)],
2017 facts: Vec::new(),
2018 };
2019 let compiled = DefaultRuleCompiler
2020 .compile(&schema, &program)
2021 .expect("compile unbound program");
2022
2023 let error = SemiNaiveRuntime
2024 .evaluate(&Default::default(), &compiled)
2025 .expect_err("missing extensional binding should fail");
2026 assert!(matches!(
2027 error,
2028 RuntimeError::MissingExtensionalBinding(id) if id == PredicateId::new(10)
2029 ));
2030 }
2031
2032 fn datom(entity: u64, attribute: u64, value: Value, element: u64) -> Datom {
2033 Datom {
2034 entity: EntityId::new(entity),
2035 attribute: AttributeId::new(attribute),
2036 value,
2037 op: aether_ast::OperationKind::Assert,
2038 element: ElementId::new(element),
2039 replica: aether_ast::ReplicaId::new(1),
2040 causal_context: Default::default(),
2041 provenance: DatomProvenance::default(),
2042 policy: None,
2043 }
2044 }
2045}