1mod parser;
2
3pub use parser::{DefaultDslParser, DslDocument, DslParser, ParseError};
4
5use aether_ast::{
6 AggregateFunction, AggregateTerm, Atom, AttributeId, ExtensionalFact, Literal, PredicateId,
7 RuleAst, RuleId, RuleProgram, Term, Value, Variable,
8};
9use aether_plan::{CompiledProgram, DeltaRulePlan, DependencyGraph, StronglyConnectedComponent};
10use aether_schema::{AttributeSchema, Schema, SchemaError, ValueType};
11use indexmap::{IndexMap, IndexSet};
12use thiserror::Error;
13
/// Turns a parsed rule program into a [`CompiledProgram`].
pub trait RuleCompiler {
    /// Compiles `program` against `schema`.
    ///
    /// # Errors
    ///
    /// Returns a [`CompileError`] when the implementation rejects the program.
    fn compile(
        &self,
        schema: &Schema,
        program: &RuleProgram,
    ) -> Result<CompiledProgram, CompileError>;
}
21
/// Stateless [`RuleCompiler`] implementation.
///
/// The type carries no data, so it is freely copyable and can be printed for
/// diagnostics.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultRuleCompiler;
24
25impl RuleCompiler for DefaultRuleCompiler {
26 fn compile(
27 &self,
28 schema: &Schema,
29 program: &RuleProgram,
30 ) -> Result<CompiledProgram, CompileError> {
31 let mut dependency_graph = DependencyGraph::default();
32 let mut all_predicates = IndexSet::new();
33 let mut negative_edges = Vec::new();
34 let mut delta_plans = Vec::new();
35
36 for predicate in &program.predicates {
37 schema.validate_predicate_arity(&predicate.id, predicate.arity)?;
38 all_predicates.insert(predicate.id);
39 }
40
41 for fact in &program.facts {
42 validate_fact(schema, fact)?;
43 all_predicates.insert(fact.predicate.id);
44 }
45
46 for rule in &program.rules {
47 validate_atom(schema, &rule.head)?;
48 all_predicates.insert(rule.head.predicate.id);
49
50 let positive_variables = positive_variables(rule);
51 validate_rule_safety(rule, &positive_variables)?;
52 validate_rule_types_and_aggregates(schema, rule, &positive_variables)?;
53
54 let mut source_predicates = Vec::new();
55 for literal in &rule.body {
56 let atom = literal_atom(literal);
57 validate_atom(schema, atom)?;
58 all_predicates.insert(atom.predicate.id);
59
60 match literal {
61 Literal::Positive(atom) => {
62 dependency_graph.add_edge(rule.head.predicate.id, atom.predicate.id);
63 source_predicates.push(atom.predicate.id);
64 }
65 Literal::Negative(atom) => {
66 negative_edges.push((rule.head.predicate.id, atom.predicate.id));
67 }
68 }
69 }
70
71 delta_plans.push(DeltaRulePlan {
72 rule_id: rule.id,
73 target_predicate: rule.head.predicate.id,
74 source_predicates,
75 });
76 }
77
78 for predicate in &all_predicates {
79 dependency_graph.edges.entry(*predicate).or_default();
80 }
81
82 let sccs = compute_sccs(&dependency_graph, &all_predicates);
83 let scc_lookup = build_scc_lookup(&sccs);
84 validate_recursive_aggregation(schema, program, &dependency_graph, &sccs, &scc_lookup)?;
85 for (head, dependency) in &negative_edges {
86 if scc_lookup.get(head) == scc_lookup.get(dependency) {
87 return Err(CompileError::UnstratifiedNegation {
88 depender: predicate_label(schema, *head),
89 dependency: predicate_label(schema, *dependency),
90 });
91 }
92 }
93 let predicate_strata =
94 compute_predicate_strata(schema, &dependency_graph, &scc_lookup, &negative_edges)?;
95
96 let phase_graph = build_phase_graph(schema, &dependency_graph, &sccs, &scc_lookup);
97 let extensional_bindings = infer_extensional_bindings(schema, program)?;
98
99 Ok(CompiledProgram {
100 dependency_graph,
101 sccs,
102 phase_graph,
103 delta_plans,
104 materialized: program.materialized.clone(),
105 rules: program.rules.clone(),
106 extensional_bindings,
107 facts: program.facts.clone(),
108 predicate_strata,
109 })
110 }
111}
112
113fn validate_atom(schema: &Schema, atom: &Atom) -> Result<(), CompileError> {
114 schema.validate_predicate_arity(&atom.predicate.id, atom.terms.len())?;
115 Ok(())
116}
117
118fn validate_fact(schema: &Schema, fact: &ExtensionalFact) -> Result<(), CompileError> {
119 schema.validate_predicate_arity(&fact.predicate.id, fact.values.len())?;
120 let signature = schema
121 .predicate(&fact.predicate.id)
122 .ok_or(SchemaError::UnknownPredicate(fact.predicate.id))?;
123 for (value, expected) in fact.values.iter().zip(&signature.fields) {
124 if !value_matches_type(value, expected) {
125 return Err(CompileError::FactTypeMismatch {
126 predicate: fact.predicate.name.clone(),
127 expected: signature.fields.clone(),
128 actual: fact.values.iter().map(value_type_of).collect(),
129 });
130 }
131 }
132 Ok(())
133}
134
135fn positive_variables(rule: &RuleAst) -> IndexSet<Variable> {
136 let mut variables = IndexSet::new();
137 for literal in &rule.body {
138 if let Literal::Positive(atom) = literal {
139 variables.extend(atom_variables(atom));
140 }
141 }
142 variables
143}
144
145fn validate_rule_safety(
146 rule: &RuleAst,
147 positive_variables: &IndexSet<Variable>,
148) -> Result<(), CompileError> {
149 for variable in atom_variables(&rule.head) {
150 if !positive_variables.contains(&variable) {
151 return Err(CompileError::UnsafeVariable {
152 rule_id: rule.id,
153 variable: variable.0,
154 });
155 }
156 }
157
158 for literal in &rule.body {
159 if let Literal::Negative(atom) = literal {
160 for variable in atom_variables(atom) {
161 if !positive_variables.contains(&variable) {
162 return Err(CompileError::UnsafeVariable {
163 rule_id: rule.id,
164 variable: variable.0,
165 });
166 }
167 }
168 }
169 }
170
171 Ok(())
172}
173
174fn atom_variables(atom: &Atom) -> IndexSet<Variable> {
175 atom.terms
176 .iter()
177 .filter_map(|term| match term {
178 Term::Variable(variable) => Some(variable.clone()),
179 Term::Aggregate(aggregate) => Some(aggregate.variable.clone()),
180 Term::Value(_) => None,
181 })
182 .collect()
183}
184
185fn literal_atom(literal: &Literal) -> &Atom {
186 match literal {
187 Literal::Positive(atom) | Literal::Negative(atom) => atom,
188 }
189}
190
191fn validate_rule_types_and_aggregates(
192 schema: &Schema,
193 rule: &RuleAst,
194 positive_variables: &IndexSet<Variable>,
195) -> Result<(), CompileError> {
196 let variable_types = infer_rule_variable_types(schema, rule)?;
197 let signature = schema
198 .predicate(&rule.head.predicate.id)
199 .expect("validated rule head predicate is present in schema");
200 let aggregates = head_aggregates(rule);
201
202 for (_, aggregate_term) in &aggregates {
203 if rule.head.terms.iter().any(
204 |term| matches!(term, Term::Variable(variable) if variable == &aggregate_term.variable),
205 ) {
206 return Err(CompileError::AggregateVariableInGroupKey {
207 rule_id: rule.id,
208 variable: aggregate_term.variable.0.clone(),
209 });
210 }
211 }
212
213 for (position, (term, expected)) in rule.head.terms.iter().zip(&signature.fields).enumerate() {
214 match term {
215 Term::Variable(variable) => {
216 let actual = variable_types.get(variable).ok_or_else(|| {
217 CompileError::RuleVariableTypeUnknown {
218 rule_id: rule.id,
219 variable: variable.0.clone(),
220 }
221 })?;
222 if actual != expected {
223 return Err(CompileError::RuleTermTypeMismatch {
224 rule_id: rule.id,
225 predicate: signature.name.clone(),
226 position,
227 expected: expected.clone(),
228 actual: actual.clone(),
229 });
230 }
231 }
232 Term::Value(value) => {
233 if !value_matches_type(value, expected) {
234 return Err(CompileError::RuleTermTypeMismatch {
235 rule_id: rule.id,
236 predicate: signature.name.clone(),
237 position,
238 expected: expected.clone(),
239 actual: value_type_of(value),
240 });
241 }
242 }
243 Term::Aggregate(aggregate_term) => {
244 if !positive_variables.contains(&aggregate_term.variable) {
245 return Err(CompileError::UnsafeVariable {
246 rule_id: rule.id,
247 variable: aggregate_term.variable.0.clone(),
248 });
249 }
250 let input_type = variable_types
251 .get(&aggregate_term.variable)
252 .ok_or_else(|| CompileError::RuleVariableTypeUnknown {
253 rule_id: rule.id,
254 variable: aggregate_term.variable.0.clone(),
255 })?;
256 let output_type =
257 aggregate_output_type(rule.id, aggregate_term, input_type.clone())?;
258 if &output_type != expected {
259 return Err(CompileError::AggregateOutputTypeMismatch {
260 rule_id: rule.id,
261 function: aggregate_term.function,
262 expected: expected.clone(),
263 actual: output_type,
264 });
265 }
266 }
267 }
268 }
269
270 Ok(())
271}
272
273fn infer_rule_variable_types(
274 schema: &Schema,
275 rule: &RuleAst,
276) -> Result<IndexMap<Variable, ValueType>, CompileError> {
277 let mut variable_types = IndexMap::new();
278 validate_atom_term_types(schema, rule.id, &rule.head, true, &mut variable_types)?;
279 for literal in &rule.body {
280 validate_atom_term_types(
281 schema,
282 rule.id,
283 literal_atom(literal),
284 false,
285 &mut variable_types,
286 )?;
287 }
288 Ok(variable_types)
289}
290
291fn validate_atom_term_types(
292 schema: &Schema,
293 rule_id: RuleId,
294 atom: &Atom,
295 allow_head_aggregate: bool,
296 variable_types: &mut IndexMap<Variable, ValueType>,
297) -> Result<(), CompileError> {
298 let signature = schema
299 .predicate(&atom.predicate.id)
300 .expect("validated atom predicate is present in schema");
301
302 for (position, (term, expected)) in atom.terms.iter().zip(&signature.fields).enumerate() {
303 match term {
304 Term::Variable(variable) => {
305 if let Some(existing) = variable_types.get(variable) {
306 if existing != expected {
307 return Err(CompileError::RuleVariableTypeConflict {
308 rule_id,
309 variable: variable.0.clone(),
310 first: existing.clone(),
311 second: expected.clone(),
312 });
313 }
314 } else {
315 variable_types.insert(variable.clone(), expected.clone());
316 }
317 }
318 Term::Value(value) => {
319 if !value_matches_type(value, expected) {
320 return Err(CompileError::RuleTermTypeMismatch {
321 rule_id,
322 predicate: signature.name.clone(),
323 position,
324 expected: expected.clone(),
325 actual: value_type_of(value),
326 });
327 }
328 }
329 Term::Aggregate(_) if !allow_head_aggregate => {
330 return Err(CompileError::AggregateOutsideHead { rule_id });
331 }
332 Term::Aggregate(_) => {}
333 }
334 }
335
336 Ok(())
337}
338
339fn head_aggregates(rule: &RuleAst) -> Vec<(usize, &AggregateTerm)> {
340 rule.head
341 .terms
342 .iter()
343 .enumerate()
344 .filter_map(|(index, term)| match term {
345 Term::Aggregate(aggregate_term) => Some((index, aggregate_term)),
346 _ => None,
347 })
348 .collect()
349}
350
351fn aggregate_output_type(
352 rule_id: RuleId,
353 aggregate: &AggregateTerm,
354 input_type: ValueType,
355) -> Result<ValueType, CompileError> {
356 match aggregate.function {
357 AggregateFunction::Count => Ok(ValueType::U64),
358 AggregateFunction::Sum => match input_type {
359 ValueType::I64 | ValueType::U64 | ValueType::F64 => Ok(input_type),
360 other => Err(CompileError::UnsupportedAggregateInputType {
361 rule_id,
362 function: aggregate.function,
363 variable: aggregate.variable.0.clone(),
364 input_type: other,
365 }),
366 },
367 AggregateFunction::Min | AggregateFunction::Max => match input_type {
368 ValueType::I64
369 | ValueType::U64
370 | ValueType::F64
371 | ValueType::String
372 | ValueType::Entity => Ok(input_type),
373 other => Err(CompileError::UnsupportedAggregateInputType {
374 rule_id,
375 function: aggregate.function,
376 variable: aggregate.variable.0.clone(),
377 input_type: other,
378 }),
379 },
380 }
381}
382
383fn validate_recursive_aggregation(
384 schema: &Schema,
385 program: &RuleProgram,
386 graph: &DependencyGraph,
387 sccs: &[StronglyConnectedComponent],
388 scc_lookup: &IndexMap<PredicateId, usize>,
389) -> Result<(), CompileError> {
390 for rule in &program.rules {
391 if head_aggregates(rule).is_empty() {
392 continue;
393 }
394
395 let scc_id = *scc_lookup
396 .get(&rule.head.predicate.id)
397 .expect("aggregate head predicate should be present in scc lookup");
398 let recursive = sccs.iter().find(|scc| scc.id == scc_id).is_some_and(|scc| {
399 scc.predicates.len() > 1
400 || scc.predicates.iter().any(|predicate| {
401 graph
402 .edges
403 .get(predicate)
404 .is_some_and(|deps| deps.contains(predicate))
405 })
406 });
407 if recursive {
408 return Err(CompileError::RecursiveAggregation {
409 rule_id: rule.id,
410 predicate: predicate_label(schema, rule.head.predicate.id),
411 });
412 }
413 }
414
415 Ok(())
416}
417
418fn compute_sccs(
419 graph: &DependencyGraph,
420 predicates: &IndexSet<PredicateId>,
421) -> Vec<StronglyConnectedComponent> {
422 let mut visited = IndexSet::new();
423 let mut order = Vec::new();
424
425 for predicate in predicates {
426 dfs_forward(*predicate, graph, &mut visited, &mut order);
427 }
428
429 let reversed = reverse_graph(graph, predicates);
430 visited.clear();
431
432 let mut sccs = Vec::new();
433 let mut next_id = 0usize;
434 while let Some(predicate) = order.pop() {
435 if visited.contains(&predicate) {
436 continue;
437 }
438 let mut component = Vec::new();
439 dfs_reverse(predicate, &reversed, &mut visited, &mut component);
440 component.sort();
441 sccs.push(StronglyConnectedComponent {
442 id: next_id,
443 predicates: component,
444 });
445 next_id += 1;
446 }
447
448 sccs
449}
450
451fn dfs_forward(
452 start: PredicateId,
453 graph: &DependencyGraph,
454 visited: &mut IndexSet<PredicateId>,
455 order: &mut Vec<PredicateId>,
456) {
457 if !visited.insert(start) {
458 return;
459 }
460
461 if let Some(neighbors) = graph.edges.get(&start) {
462 for neighbor in neighbors {
463 dfs_forward(*neighbor, graph, visited, order);
464 }
465 }
466
467 order.push(start);
468}
469
470fn reverse_graph(
471 graph: &DependencyGraph,
472 predicates: &IndexSet<PredicateId>,
473) -> IndexMap<PredicateId, Vec<PredicateId>> {
474 let mut reversed: IndexMap<PredicateId, Vec<PredicateId>> = predicates
475 .iter()
476 .map(|predicate| (*predicate, Vec::new()))
477 .collect();
478
479 for (head, dependencies) in &graph.edges {
480 for dependency in dependencies {
481 reversed.entry(*dependency).or_default().push(*head);
482 }
483 }
484
485 reversed
486}
487
488fn dfs_reverse(
489 start: PredicateId,
490 graph: &IndexMap<PredicateId, Vec<PredicateId>>,
491 visited: &mut IndexSet<PredicateId>,
492 component: &mut Vec<PredicateId>,
493) {
494 if !visited.insert(start) {
495 return;
496 }
497
498 component.push(start);
499 if let Some(neighbors) = graph.get(&start) {
500 for neighbor in neighbors {
501 dfs_reverse(*neighbor, graph, visited, component);
502 }
503 }
504}
505
506fn build_scc_lookup(sccs: &[StronglyConnectedComponent]) -> IndexMap<PredicateId, usize> {
507 let mut lookup = IndexMap::new();
508 for scc in sccs {
509 for predicate in &scc.predicates {
510 lookup.insert(*predicate, scc.id);
511 }
512 }
513 lookup
514}
515
516fn build_phase_graph(
517 schema: &Schema,
518 graph: &DependencyGraph,
519 sccs: &[StronglyConnectedComponent],
520 scc_lookup: &IndexMap<PredicateId, usize>,
521) -> aether_ast::PhaseGraph {
522 let mut nodes = Vec::new();
523 let mut edges = IndexSet::new();
524
525 for scc in sccs {
526 let provides: Vec<String> = scc
527 .predicates
528 .iter()
529 .map(|predicate| predicate_label(schema, *predicate))
530 .collect();
531 let mut available = Vec::new();
532
533 for predicate in &scc.predicates {
534 if let Some(dependencies) = graph.edges.get(predicate) {
535 for dependency in dependencies {
536 let dependency_scc = *scc_lookup
537 .get(dependency)
538 .expect("predicate present in scc lookup");
539 if dependency_scc != scc.id {
540 available.push(predicate_label(schema, *dependency));
541 edges.insert((dependency_scc, scc.id));
542 }
543 }
544 }
545 }
546
547 available.sort();
548 available.dedup();
549
550 let recursive = scc.predicates.len() > 1
551 || scc.predicates.iter().any(|predicate| {
552 graph
553 .edges
554 .get(predicate)
555 .is_some_and(|deps| deps.contains(predicate))
556 });
557
558 nodes.push(aether_ast::PhaseNode {
559 id: format!("scc-{}", scc.id),
560 signature: aether_ast::PhaseSignature {
561 available,
562 provides: provides.clone(),
563 keep: provides,
564 },
565 recursive_scc: recursive.then_some(scc.id),
566 });
567 }
568
569 let edges = edges
570 .into_iter()
571 .map(|(from, to)| aether_ast::PhaseEdge {
572 from: format!("scc-{}", from),
573 to: format!("scc-{}", to),
574 })
575 .collect();
576
577 aether_ast::PhaseGraph { nodes, edges }
578}
579
580fn compute_predicate_strata(
581 _schema: &Schema,
582 graph: &DependencyGraph,
583 scc_lookup: &IndexMap<PredicateId, usize>,
584 negative_edges: &[(PredicateId, PredicateId)],
585) -> Result<IndexMap<PredicateId, usize>, CompileError> {
586 let mut condensed_edges: IndexMap<(usize, usize), usize> = IndexMap::new();
587 let mut scc_ids = IndexSet::new();
588 for scc_id in scc_lookup.values() {
589 scc_ids.insert(*scc_id);
590 }
591
592 for (head, dependencies) in &graph.edges {
593 let to = *scc_lookup
594 .get(head)
595 .expect("head predicate should be present in scc lookup");
596 for dependency in dependencies {
597 let from = *scc_lookup
598 .get(dependency)
599 .expect("dependency predicate should be present in scc lookup");
600 if from != to {
601 scc_ids.insert(from);
602 scc_ids.insert(to);
603 condensed_edges.entry((from, to)).or_insert(0);
604 }
605 }
606 }
607
608 for (head, dependency) in negative_edges {
609 let to = *scc_lookup
610 .get(head)
611 .expect("negative head predicate should be present in scc lookup");
612 let from = *scc_lookup
613 .get(dependency)
614 .expect("negative dependency predicate should be present in scc lookup");
615 if from != to {
616 scc_ids.insert(from);
617 scc_ids.insert(to);
618 condensed_edges
619 .entry((from, to))
620 .and_modify(|weight| *weight = (*weight).max(1))
621 .or_insert(1);
622 }
623 }
624
625 let mut outgoing: IndexMap<usize, Vec<(usize, usize)>> = scc_ids
626 .iter()
627 .copied()
628 .map(|scc_id| (scc_id, Vec::new()))
629 .collect();
630 let mut indegree: IndexMap<usize, usize> = scc_ids
631 .iter()
632 .copied()
633 .map(|scc_id| (scc_id, 0usize))
634 .collect();
635
636 for ((from, to), weight) in condensed_edges {
637 outgoing.entry(from).or_default().push((to, weight));
638 *indegree.entry(to).or_default() += 1;
639 }
640
641 let mut ready = indegree
642 .iter()
643 .filter_map(|(scc_id, degree)| (*degree == 0).then_some(*scc_id))
644 .collect::<Vec<_>>();
645 ready.sort_unstable();
646
647 let mut order = Vec::new();
648 while let Some(scc_id) = ready.first().copied() {
649 ready.remove(0);
650 order.push(scc_id);
651 if let Some(edges) = outgoing.get(&scc_id) {
652 for (to, _) in edges {
653 let degree = indegree
654 .get_mut(to)
655 .expect("target scc should have indegree entry");
656 *degree -= 1;
657 if *degree == 0 {
658 ready.push(*to);
659 ready.sort_unstable();
660 }
661 }
662 }
663 }
664
665 if order.len() != indegree.len() {
666 return Err(CompileError::UnstratifiedNegation {
667 depender: "program".into(),
668 dependency: "negative cycle".into(),
669 });
670 }
671
672 let mut scc_strata: IndexMap<usize, usize> = scc_ids
673 .iter()
674 .copied()
675 .map(|scc_id| (scc_id, 0usize))
676 .collect();
677 for scc_id in order {
678 let current = *scc_strata
679 .get(&scc_id)
680 .expect("source scc should have a stratum");
681 if let Some(edges) = outgoing.get(&scc_id) {
682 for (to, weight) in edges {
683 let target = scc_strata
684 .get_mut(to)
685 .expect("target scc should have a stratum");
686 *target = (*target).max(current + *weight);
687 }
688 }
689 }
690
691 Ok(scc_lookup
692 .iter()
693 .map(|(predicate, scc_id)| {
694 (
695 *predicate,
696 *scc_strata
697 .get(scc_id)
698 .expect("predicate scc should have a stratum"),
699 )
700 })
701 .collect())
702}
703
704fn infer_extensional_bindings(
705 schema: &Schema,
706 program: &RuleProgram,
707) -> Result<IndexMap<PredicateId, AttributeId>, CompileError> {
708 let mut bindings = IndexMap::new();
709
710 for predicate in &program.predicates {
711 if predicate.arity != 2 {
712 continue;
713 }
714
715 if let Some(attribute) = matching_attribute(schema, &predicate.name) {
716 validate_extensional_binding(schema, predicate.id, attribute)?;
717 bindings.insert(predicate.id, attribute.id);
718 }
719 }
720
721 Ok(bindings)
722}
723
724fn matching_attribute<'a>(schema: &'a Schema, predicate_name: &str) -> Option<&'a AttributeSchema> {
725 let mut candidates = vec![predicate_name.to_owned()];
726 if predicate_name.contains('_') {
727 candidates.push(predicate_name.replacen('_', ".", 1));
728 candidates.push(predicate_name.replace('_', "."));
729 }
730
731 candidates.dedup();
732
733 candidates.into_iter().find_map(|candidate| {
734 schema
735 .attributes
736 .values()
737 .find(|attribute| attribute.name == candidate)
738 })
739}
740
741fn validate_extensional_binding(
742 schema: &Schema,
743 predicate: PredicateId,
744 attribute: &AttributeSchema,
745) -> Result<(), CompileError> {
746 let signature = schema
747 .predicate(&predicate)
748 .expect("validated predicates are present in schema");
749 let expected_fields = vec![ValueType::Entity, attribute.value_type.clone()];
750
751 if signature.fields != expected_fields {
752 return Err(CompileError::IncompatibleExtensionalBinding {
753 predicate: signature.name.clone(),
754 attribute: attribute.name.clone(),
755 expected_fields,
756 actual_fields: signature.fields.clone(),
757 });
758 }
759
760 Ok(())
761}
762
763fn predicate_label(schema: &Schema, predicate: PredicateId) -> String {
764 schema
765 .predicate(&predicate)
766 .map(|signature| signature.name.clone())
767 .unwrap_or_else(|| format!("predicate-{}", predicate))
768}
769
770fn value_matches_type(value: &Value, expected: &ValueType) -> bool {
771 match (value, expected) {
772 (Value::Null, _) => true,
773 (Value::Bool(_), ValueType::Bool) => true,
774 (Value::I64(_), ValueType::I64) => true,
775 (Value::U64(_), ValueType::U64) => true,
776 (Value::F64(_), ValueType::F64) => true,
777 (Value::String(_), ValueType::String) => true,
778 (Value::Bytes(_), ValueType::Bytes) => true,
779 (Value::Entity(_), ValueType::Entity) => true,
780 (Value::List(values), ValueType::List(inner)) => {
781 values.iter().all(|value| value_matches_type(value, inner))
782 }
783 _ => false,
784 }
785}
786
787fn value_type_of(value: &Value) -> ValueType {
788 match value {
789 Value::Null => ValueType::String,
790 Value::Bool(_) => ValueType::Bool,
791 Value::I64(_) => ValueType::I64,
792 Value::U64(_) => ValueType::U64,
793 Value::F64(_) => ValueType::F64,
794 Value::String(_) => ValueType::String,
795 Value::Bytes(_) => ValueType::Bytes,
796 Value::Entity(_) => ValueType::Entity,
797 Value::List(values) => ValueType::List(Box::new(
798 values
799 .first()
800 .map(value_type_of)
801 .unwrap_or(ValueType::String),
802 )),
803 }
804}
805
806#[derive(Debug, Error)]
807pub enum CompileError {
808 #[error(transparent)]
809 Schema(#[from] SchemaError),
810 #[error("rule {rule_id} uses aggregate terms outside a rule head")]
811 AggregateOutsideHead { rule_id: RuleId },
812 #[error("rule {rule_id} cannot group by aggregate variable {variable}")]
813 AggregateVariableInGroupKey { rule_id: RuleId, variable: String },
814 #[error(
815 "rule {rule_id} uses aggregate {function} over variable {variable} with unsupported input type {input_type:?}"
816 )]
817 UnsupportedAggregateInputType {
818 rule_id: RuleId,
819 function: AggregateFunction,
820 variable: String,
821 input_type: ValueType,
822 },
823 #[error(
824 "rule {rule_id} produces aggregate {function} with type {actual:?}, but the head expects {expected:?}"
825 )]
826 AggregateOutputTypeMismatch {
827 rule_id: RuleId,
828 function: AggregateFunction,
829 expected: ValueType,
830 actual: ValueType,
831 },
832 #[error(
833 "rule {rule_id} uses variable {variable} with incompatible types {first:?} and {second:?}"
834 )]
835 RuleVariableTypeConflict {
836 rule_id: RuleId,
837 variable: String,
838 first: ValueType,
839 second: ValueType,
840 },
841 #[error("rule {rule_id} references variable {variable}, but its type could not be inferred")]
842 RuleVariableTypeUnknown { rule_id: RuleId, variable: String },
843 #[error(
844 "rule {rule_id} uses term {position} of predicate {predicate} with type {actual:?}, expected {expected:?}"
845 )]
846 RuleTermTypeMismatch {
847 rule_id: RuleId,
848 predicate: String,
849 position: usize,
850 expected: ValueType,
851 actual: ValueType,
852 },
853 #[error("rule {rule_id} uses bounded aggregation recursively through predicate {predicate}")]
854 RecursiveAggregation { rule_id: RuleId, predicate: String },
855 #[error(
856 "predicate {predicate} cannot bind to attribute {attribute}: expected {expected_fields:?}, found {actual_fields:?}"
857 )]
858 IncompatibleExtensionalBinding {
859 predicate: String,
860 attribute: String,
861 expected_fields: Vec<ValueType>,
862 actual_fields: Vec<ValueType>,
863 },
864 #[error(
865 "fact for predicate {predicate} does not match the declared types: expected {expected:?}, found {actual:?}"
866 )]
867 FactTypeMismatch {
868 predicate: String,
869 expected: Vec<ValueType>,
870 actual: Vec<ValueType>,
871 },
872 #[error("rule {rule_id} uses unsafe variable {variable}")]
873 UnsafeVariable {
874 rule_id: aether_ast::RuleId,
875 variable: String,
876 },
877 #[error("unstratified negation detected: {depender} depends negatively on {dependency}")]
878 UnstratifiedNegation {
879 depender: String,
880 dependency: String,
881 },
882}
883
884#[cfg(test)]
885mod tests {
886 use super::{CompileError, DefaultRuleCompiler, RuleCompiler};
887 use aether_ast::{
888 AggregateFunction, AggregateTerm, Atom, AttributeId, ExtensionalFact, Literal, PredicateId,
889 PredicateRef, RuleAst, RuleId, RuleProgram, Term, Value, Variable,
890 };
891 use aether_schema::{AttributeClass, AttributeSchema, PredicateSignature, Schema, ValueType};
892
893 fn predicate(id: u64, name: &str, arity: usize) -> PredicateRef {
894 PredicateRef {
895 id: PredicateId::new(id),
896 name: name.into(),
897 arity,
898 }
899 }
900
901 fn atom(predicate: PredicateRef, vars: &[&str]) -> Atom {
902 Atom {
903 predicate,
904 terms: vars
905 .iter()
906 .map(|name| Term::Variable(Variable::new(*name)))
907 .collect(),
908 }
909 }
910
911 fn aggregate(function: AggregateFunction, variable: &str) -> Term {
912 Term::Aggregate(AggregateTerm {
913 function,
914 variable: Variable::new(variable),
915 })
916 }
917
918 fn schema(predicates: &[(u64, &str, usize)]) -> Schema {
919 let mut schema = Schema::new("v1");
920 for (id, name, arity) in predicates {
921 schema
922 .register_predicate(PredicateSignature {
923 id: PredicateId::new(*id),
924 name: (*name).into(),
925 fields: vec![ValueType::Entity; *arity],
926 })
927 .expect("register predicate");
928 }
929 schema
930 }
931
932 #[test]
933 fn safe_recursive_program_builds_expected_graph_and_phase_boundaries() {
934 let edge = predicate(1, "edge", 2);
935 let reach = predicate(2, "reach", 2);
936 let schema = schema(&[(1, "edge", 2), (2, "reach", 2)]);
937 let program = RuleProgram {
938 predicates: vec![edge.clone(), reach.clone()],
939 rules: vec![
940 RuleAst {
941 id: RuleId::new(1),
942 head: atom(reach.clone(), &["x", "y"]),
943 body: vec![Literal::Positive(atom(edge.clone(), &["x", "y"]))],
944 },
945 RuleAst {
946 id: RuleId::new(2),
947 head: atom(reach.clone(), &["x", "z"]),
948 body: vec![
949 Literal::Positive(atom(reach.clone(), &["x", "y"])),
950 Literal::Positive(atom(edge.clone(), &["y", "z"])),
951 ],
952 },
953 ],
954 materialized: vec![reach.id],
955 facts: Vec::new(),
956 };
957
958 let compiled = DefaultRuleCompiler
959 .compile(&schema, &program)
960 .expect("compile recursive program");
961 let reach_edges = compiled
962 .dependency_graph
963 .edges
964 .get(&reach.id)
965 .expect("reach edges");
966
967 assert!(reach_edges.contains(&edge.id));
968 assert!(reach_edges.contains(&reach.id));
969 assert_eq!(compiled.sccs.len(), 2);
970
971 let reach_scc = compiled
972 .sccs
973 .iter()
974 .find(|scc| scc.predicates.contains(&reach.id))
975 .expect("reach scc");
976 let edge_scc = compiled
977 .sccs
978 .iter()
979 .find(|scc| scc.predicates.contains(&edge.id))
980 .expect("edge scc");
981 let reach_node = compiled
982 .phase_graph
983 .nodes
984 .iter()
985 .find(|node| node.id == format!("scc-{}", reach_scc.id))
986 .expect("reach phase node");
987 let edge_node = compiled
988 .phase_graph
989 .nodes
990 .iter()
991 .find(|node| node.id == format!("scc-{}", edge_scc.id))
992 .expect("edge phase node");
993
994 assert_eq!(reach_node.recursive_scc, Some(reach_scc.id));
995 assert_eq!(edge_node.recursive_scc, None);
996 assert_eq!(compiled.predicate_strata.get(&edge.id).copied(), Some(0));
997 assert_eq!(compiled.predicate_strata.get(&reach.id).copied(), Some(0));
998 assert!(compiled.phase_graph.edges.iter().any(|edge_ref| {
999 edge_ref.from == format!("scc-{}", edge_scc.id)
1000 && edge_ref.to == format!("scc-{}", reach_scc.id)
1001 }));
1002 assert_eq!(compiled.rules, program.rules);
1003 }
1004
1005 #[test]
1006 fn extensional_predicates_bind_to_matching_attribute_names() {
1007 let task_depends_on = predicate(10, "task_depends_on", 2);
1008 let depends_transitive = predicate(11, "depends_transitive", 2);
1009 let mut schema = schema(&[(10, "task_depends_on", 2), (11, "depends_transitive", 2)]);
1010 schema
1011 .register_attribute(AttributeSchema {
1012 id: AttributeId::new(21),
1013 name: "task.depends_on".into(),
1014 class: AttributeClass::RefSet,
1015 value_type: ValueType::Entity,
1016 })
1017 .expect("register attribute");
1018
1019 let compiled = DefaultRuleCompiler
1020 .compile(
1021 &schema,
1022 &RuleProgram {
1023 predicates: vec![task_depends_on.clone(), depends_transitive.clone()],
1024 rules: vec![RuleAst {
1025 id: RuleId::new(1),
1026 head: atom(depends_transitive, &["x", "y"]),
1027 body: vec![Literal::Positive(atom(
1028 task_depends_on.clone(),
1029 &["x", "y"],
1030 ))],
1031 }],
1032 materialized: vec![task_depends_on.id],
1033 facts: Vec::new(),
1034 },
1035 )
1036 .expect("compile program");
1037
1038 assert_eq!(
1039 compiled.extensional_bindings.get(&task_depends_on.id),
1040 Some(&AttributeId::new(21))
1041 );
1042 }
1043
1044 #[test]
1045 fn bounded_aggregation_requires_non_recursive_rules_and_matching_output_types() {
1046 let edge = predicate(1, "edge", 2);
1047 let reach_count = predicate(2, "reach_count", 2);
1048 let mut aggregate_schema = Schema::new("v1");
1049 aggregate_schema
1050 .register_predicate(PredicateSignature {
1051 id: edge.id,
1052 name: edge.name.clone(),
1053 fields: vec![ValueType::Entity, ValueType::Entity],
1054 })
1055 .expect("register edge");
1056 aggregate_schema
1057 .register_predicate(PredicateSignature {
1058 id: reach_count.id,
1059 name: reach_count.name.clone(),
1060 fields: vec![ValueType::Entity, ValueType::String],
1061 })
1062 .expect("register aggregate predicate");
1063
1064 let type_mismatch = DefaultRuleCompiler
1065 .compile(
1066 &aggregate_schema,
1067 &RuleProgram {
1068 predicates: vec![edge.clone(), reach_count.clone()],
1069 rules: vec![RuleAst {
1070 id: RuleId::new(1),
1071 head: Atom {
1072 predicate: reach_count.clone(),
1073 terms: vec![
1074 Term::Variable(Variable::new("x")),
1075 aggregate(AggregateFunction::Count, "y"),
1076 ],
1077 },
1078 body: vec![Literal::Positive(atom(edge.clone(), &["x", "y"]))],
1079 }],
1080 materialized: vec![reach_count.id],
1081 facts: Vec::new(),
1082 },
1083 )
1084 .expect_err("aggregate output type mismatch should fail");
1085 assert!(matches!(
1086 type_mismatch,
1087 CompileError::AggregateOutputTypeMismatch {
1088 rule_id,
1089 function: AggregateFunction::Count,
1090 expected: ValueType::String,
1091 actual: ValueType::U64,
1092 } if rule_id == RuleId::new(1)
1093 ));
1094
1095 let mut recursive_schema = Schema::new("v1");
1096 recursive_schema
1097 .register_predicate(PredicateSignature {
1098 id: PredicateId::new(1),
1099 name: "edge".into(),
1100 fields: vec![ValueType::Entity, ValueType::Entity],
1101 })
1102 .expect("register edge");
1103 recursive_schema
1104 .register_predicate(PredicateSignature {
1105 id: PredicateId::new(3),
1106 name: "bad_count".into(),
1107 fields: vec![ValueType::Entity, ValueType::U64],
1108 })
1109 .expect("register recursive aggregate predicate");
1110 let recursive = DefaultRuleCompiler
1111 .compile(
1112 &recursive_schema,
1113 &RuleProgram {
1114 predicates: vec![edge, predicate(3, "bad_count", 2)],
1115 rules: vec![RuleAst {
1116 id: RuleId::new(2),
1117 head: Atom {
1118 predicate: predicate(3, "bad_count", 2),
1119 terms: vec![
1120 Term::Variable(Variable::new("x")),
1121 aggregate(AggregateFunction::Count, "y"),
1122 ],
1123 },
1124 body: vec![Literal::Positive(atom(
1125 predicate(3, "bad_count", 2),
1126 &["x", "y"],
1127 ))],
1128 }],
1129 materialized: vec![PredicateId::new(3)],
1130 facts: Vec::new(),
1131 },
1132 )
1133 .expect_err("recursive aggregate should fail");
1134 assert!(matches!(
1135 recursive,
1136 CompileError::RecursiveAggregation { rule_id, predicate }
1137 if rule_id == RuleId::new(2) && predicate == "bad_count"
1138 ));
1139 }
1140
/// Checks that a single rule head may carry more than one aggregate term.
///
/// Informally, the rule under test is
/// `project_stats(project, count(task), sum(hours)) :-
///      project_task(project, task), task_hours(task, hours).`
/// and the compiler is expected to accept it.
#[test]
fn bounded_aggregation_allows_multiple_head_aggregates() {
    let membership = predicate(1, "project_task", 2);
    let effort = predicate(2, "task_hours", 2);
    let summary = predicate(3, "project_stats", 3);

    // Pair each predicate with its column types and register them in order.
    let mut registry = Schema::new("v1");
    let column_types = [
        (&membership, vec![ValueType::Entity, ValueType::Entity]),
        (&effort, vec![ValueType::Entity, ValueType::U64]),
        (
            &summary,
            vec![ValueType::Entity, ValueType::U64, ValueType::U64],
        ),
    ];
    for (declared, fields) in column_types {
        let signature = PredicateSignature {
            id: declared.id,
            name: declared.name.clone(),
            fields,
        };
        registry
            .register_predicate(signature)
            .expect("register predicate");
    }

    // One grouping variable followed by two aggregates in the head.
    let stats_rule = RuleAst {
        id: RuleId::new(10),
        head: Atom {
            predicate: summary.clone(),
            terms: vec![
                Term::Variable(Variable::new("project")),
                aggregate(AggregateFunction::Count, "task"),
                aggregate(AggregateFunction::Sum, "hours"),
            ],
        },
        body: vec![
            Literal::Positive(atom(membership.clone(), &["project", "task"])),
            Literal::Positive(atom(effort.clone(), &["task", "hours"])),
        ],
    };
    let program = RuleProgram {
        predicates: vec![membership, effort, summary],
        rules: vec![stats_rule],
        materialized: vec![PredicateId::new(3)],
        facts: Vec::new(),
    };

    let outcome = DefaultRuleCompiler.compile(&registry, &program);
    outcome.expect("compile multiple head aggregates");
}
1199
/// Checks that compilation fails with `IncompatibleExtensionalBinding` when a
/// predicate's declared field types disagree with the attribute it is bound to.
///
/// The predicate `task_depends_on` is registered as `(String, Entity)` while
/// the attribute `task.depends_on` is a `RefSet` of `Entity`; the error is
/// expected to report `(Entity, Entity)` as the required shape.
// NOTE(review): how the predicate gets associated with the attribute is decided
// outside this test (presumably by name) — confirm against the compiler.
#[test]
fn extensional_binding_rejects_type_mismatches() {
    let bound_predicate = predicate(10, "task_depends_on", 2);

    let mut registry = Schema::new("v1");
    // The first column is `String` here, whereas the asserted error expects `Entity`.
    let mismatched_signature = PredicateSignature {
        id: bound_predicate.id,
        name: bound_predicate.name.clone(),
        fields: vec![ValueType::String, ValueType::Entity],
    };
    registry
        .register_predicate(mismatched_signature)
        .expect("register predicate");

    let reference_attribute = AttributeSchema {
        id: AttributeId::new(21),
        name: "task.depends_on".into(),
        class: AttributeClass::RefSet,
        value_type: ValueType::Entity,
    };
    registry
        .register_attribute(reference_attribute)
        .expect("register attribute");

    // No rules or facts are needed: declaring the predicate is enough to trigger the check.
    let program = RuleProgram {
        predicates: vec![bound_predicate],
        rules: Vec::new(),
        materialized: Vec::new(),
        facts: Vec::new(),
    };
    let outcome = DefaultRuleCompiler.compile(&registry, &program);
    let error = outcome.expect_err("type-mismatched binding should fail");

    assert!(matches!(
        error,
        CompileError::IncompatibleExtensionalBinding {
            predicate: reported_predicate,
            attribute: reported_attribute,
            expected_fields,
            actual_fields,
        } if reported_predicate == "task_depends_on"
            && reported_attribute == "task.depends_on"
            && expected_fields == vec![ValueType::Entity, ValueType::Entity]
            && actual_fields == vec![ValueType::String, ValueType::Entity]
    ));
}
1245
/// Checks that a head variable missing from the rule body is rejected.
///
/// Informally, the rule is `ready(x) :- edge(y, z).`; the compiler is expected
/// to fail with `UnsafeVariable` naming `x`.
#[test]
fn unsafe_variables_are_rejected() {
    let head_predicate = predicate(1, "ready", 1);
    let body_predicate = predicate(2, "edge", 2);
    let registry = schema(&[(1, "ready", 1), (2, "edge", 2)]);

    // The head uses `x`, while the only body literal mentions `y` and `z`.
    let offending_rule = RuleAst {
        id: RuleId::new(7),
        head: atom(head_predicate.clone(), &["x"]),
        body: vec![Literal::Positive(atom(
            body_predicate.clone(),
            &["y", "z"],
        ))],
    };
    let program = RuleProgram {
        predicates: vec![head_predicate, body_predicate],
        rules: vec![offending_rule],
        materialized: Vec::new(),
        facts: Vec::new(),
    };

    let outcome = DefaultRuleCompiler.compile(&registry, &program);
    let error = outcome.expect_err("unsafe rule should fail");
    assert!(matches!(
        error,
        CompileError::UnsafeVariable { variable: offender, .. } if offender == "x"
    ));
}
1270
/// Checks that negation between mutually dependent predicates is rejected.
///
/// Informally, the program is
/// `p(x) :- q(x).` and `q(x) :- p(x), not p(x).`
/// so `p` and `q` depend on each other while `q` also negates `p`. The
/// compiler is expected to fail with `UnstratifiedNegation` for `q` on `p`.
#[test]
fn unstratified_negation_in_recursive_component_is_rejected() {
    let p = predicate(1, "p", 1);
    let q = predicate(2, "q", 1);
    let registry = schema(&[(1, "p", 1), (2, "q", 1)]);

    // p(x) :- q(x).
    let p_from_q = RuleAst {
        id: RuleId::new(1),
        head: atom(p.clone(), &["x"]),
        body: vec![Literal::Positive(atom(q.clone(), &["x"]))],
    };
    // q(x) :- p(x), not p(x).
    let q_from_p_and_not_p = RuleAst {
        id: RuleId::new(2),
        head: atom(q.clone(), &["x"]),
        body: vec![
            Literal::Positive(atom(p.clone(), &["x"])),
            Literal::Negative(atom(p.clone(), &["x"])),
        ],
    };
    let program = RuleProgram {
        predicates: vec![p, q],
        rules: vec![p_from_q, q_from_p_and_not_p],
        materialized: Vec::new(),
        facts: Vec::new(),
    };

    let outcome = DefaultRuleCompiler.compile(&registry, &program);
    let error = outcome.expect_err("unstratified negation should fail");
    assert!(matches!(
        error,
        CompileError::UnstratifiedNegation {
            depender: negating,
            dependency: negated,
        } if negating == "q" && negated == "p"
    ));
}
1306
/// Checks stratum assignment for a program whose negation can be stratified.
///
/// Informally, the program is
/// `task_complete(x) :- task_status(x, "done").` and
/// `task_ready(x) :- task(x), not task_complete(x).`
/// The test asserts that `task_complete` lands in stratum 0, that `task_ready`
/// (which negates it) lands in stratum 1, and that the compiled program holds
/// exactly one fact.
#[test]
fn stratified_negation_assigns_higher_strata() {
    let any_task = predicate(1, "task", 1);
    let status = predicate(2, "task_status", 2);
    let complete = predicate(3, "task_complete", 1);
    let ready = predicate(4, "task_ready", 1);

    // Ids are `Copy`; keep them around for the assertions after the moves below.
    let complete_id = complete.id;
    let ready_id = ready.id;

    // Pair each predicate with its column types and register them in order.
    let mut registry = Schema::new("v1");
    let column_types = [
        (&any_task, vec![ValueType::Entity]),
        (&status, vec![ValueType::Entity, ValueType::String]),
        (&complete, vec![ValueType::Entity]),
        (&ready, vec![ValueType::Entity]),
    ];
    for (declared, fields) in column_types {
        let signature = PredicateSignature {
            id: declared.id,
            name: declared.name.clone(),
            fields,
        };
        registry
            .register_predicate(signature)
            .expect("register predicate");
    }
    let status_attribute = AttributeSchema {
        id: AttributeId::new(20),
        name: "task.status".into(),
        class: AttributeClass::ScalarLww,
        value_type: ValueType::String,
    };
    registry
        .register_attribute(status_attribute)
        .expect("register attribute");

    // task_complete(x) :- task_status(x, "done").
    let completion_rule = RuleAst {
        id: RuleId::new(1),
        head: atom(complete.clone(), &["x"]),
        body: vec![Literal::Positive(Atom {
            predicate: status.clone(),
            terms: vec![
                Term::Variable(Variable::new("x")),
                Term::Value(Value::String("done".into())),
            ],
        })],
    };
    // task_ready(x) :- task(x), not task_complete(x).
    let readiness_rule = RuleAst {
        id: RuleId::new(2),
        head: atom(ready.clone(), &["x"]),
        body: vec![
            Literal::Positive(atom(any_task.clone(), &["x"])),
            Literal::Negative(atom(complete.clone(), &["x"])),
        ],
    };
    let seed_fact = ExtensionalFact {
        predicate: any_task.clone(),
        values: vec![Value::Entity(aether_ast::EntityId::new(1))],
        policy: None,
        provenance: None,
    };
    let program = RuleProgram {
        predicates: vec![any_task, status, complete, ready],
        rules: vec![completion_rule, readiness_rule],
        materialized: vec![ready_id],
        facts: vec![seed_fact],
    };

    let compiled = DefaultRuleCompiler
        .compile(&registry, &program)
        .expect("compile stratified program");

    let strata = &compiled.predicate_strata;
    assert_eq!(strata.get(&complete_id).copied(), Some(0));
    assert_eq!(strata.get(&ready_id).copied(), Some(1));
    assert_eq!(compiled.facts.len(), 1);
}
1401}