aether_rules/
lib.rs

1mod parser;
2
3pub use parser::{DefaultDslParser, DslDocument, DslParser, ParseError};
4
5use aether_ast::{
6    AggregateFunction, AggregateTerm, Atom, AttributeId, ExtensionalFact, Literal, PredicateId,
7    RuleAst, RuleId, RuleProgram, Term, Value, Variable,
8};
9use aether_plan::{CompiledProgram, DeltaRulePlan, DependencyGraph, StronglyConnectedComponent};
10use aether_schema::{AttributeSchema, Schema, SchemaError, ValueType};
11use indexmap::{IndexMap, IndexSet};
12use thiserror::Error;
13
14pub trait RuleCompiler {
15    fn compile(
16        &self,
17        schema: &Schema,
18        program: &RuleProgram,
19    ) -> Result<CompiledProgram, CompileError>;
20}
21
/// Stateless default implementation of [`RuleCompiler`].
///
/// Being a unit struct, it is trivially `Copy`; deriving `Debug`/`Clone`/`Copy` lets it be
/// embedded in other debuggable or cloneable types.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultRuleCompiler;
24
25impl RuleCompiler for DefaultRuleCompiler {
26    fn compile(
27        &self,
28        schema: &Schema,
29        program: &RuleProgram,
30    ) -> Result<CompiledProgram, CompileError> {
31        let mut dependency_graph = DependencyGraph::default();
32        let mut all_predicates = IndexSet::new();
33        let mut negative_edges = Vec::new();
34        let mut delta_plans = Vec::new();
35
36        for predicate in &program.predicates {
37            schema.validate_predicate_arity(&predicate.id, predicate.arity)?;
38            all_predicates.insert(predicate.id);
39        }
40
41        for fact in &program.facts {
42            validate_fact(schema, fact)?;
43            all_predicates.insert(fact.predicate.id);
44        }
45
46        for rule in &program.rules {
47            validate_atom(schema, &rule.head)?;
48            all_predicates.insert(rule.head.predicate.id);
49
50            let positive_variables = positive_variables(rule);
51            validate_rule_safety(rule, &positive_variables)?;
52            validate_rule_types_and_aggregates(schema, rule, &positive_variables)?;
53
54            let mut source_predicates = Vec::new();
55            for literal in &rule.body {
56                let atom = literal_atom(literal);
57                validate_atom(schema, atom)?;
58                all_predicates.insert(atom.predicate.id);
59
60                match literal {
61                    Literal::Positive(atom) => {
62                        dependency_graph.add_edge(rule.head.predicate.id, atom.predicate.id);
63                        source_predicates.push(atom.predicate.id);
64                    }
65                    Literal::Negative(atom) => {
66                        negative_edges.push((rule.head.predicate.id, atom.predicate.id));
67                    }
68                }
69            }
70
71            delta_plans.push(DeltaRulePlan {
72                rule_id: rule.id,
73                target_predicate: rule.head.predicate.id,
74                source_predicates,
75            });
76        }
77
78        for predicate in &all_predicates {
79            dependency_graph.edges.entry(*predicate).or_default();
80        }
81
82        let sccs = compute_sccs(&dependency_graph, &all_predicates);
83        let scc_lookup = build_scc_lookup(&sccs);
84        validate_recursive_aggregation(schema, program, &dependency_graph, &sccs, &scc_lookup)?;
85        for (head, dependency) in &negative_edges {
86            if scc_lookup.get(head) == scc_lookup.get(dependency) {
87                return Err(CompileError::UnstratifiedNegation {
88                    depender: predicate_label(schema, *head),
89                    dependency: predicate_label(schema, *dependency),
90                });
91            }
92        }
93        let predicate_strata =
94            compute_predicate_strata(schema, &dependency_graph, &scc_lookup, &negative_edges)?;
95
96        let phase_graph = build_phase_graph(schema, &dependency_graph, &sccs, &scc_lookup);
97        let extensional_bindings = infer_extensional_bindings(schema, program)?;
98
99        Ok(CompiledProgram {
100            dependency_graph,
101            sccs,
102            phase_graph,
103            delta_plans,
104            materialized: program.materialized.clone(),
105            rules: program.rules.clone(),
106            extensional_bindings,
107            facts: program.facts.clone(),
108            predicate_strata,
109        })
110    }
111}
112
113fn validate_atom(schema: &Schema, atom: &Atom) -> Result<(), CompileError> {
114    schema.validate_predicate_arity(&atom.predicate.id, atom.terms.len())?;
115    Ok(())
116}
117
118fn validate_fact(schema: &Schema, fact: &ExtensionalFact) -> Result<(), CompileError> {
119    schema.validate_predicate_arity(&fact.predicate.id, fact.values.len())?;
120    let signature = schema
121        .predicate(&fact.predicate.id)
122        .ok_or(SchemaError::UnknownPredicate(fact.predicate.id))?;
123    for (value, expected) in fact.values.iter().zip(&signature.fields) {
124        if !value_matches_type(value, expected) {
125            return Err(CompileError::FactTypeMismatch {
126                predicate: fact.predicate.name.clone(),
127                expected: signature.fields.clone(),
128                actual: fact.values.iter().map(value_type_of).collect(),
129            });
130        }
131    }
132    Ok(())
133}
134
135fn positive_variables(rule: &RuleAst) -> IndexSet<Variable> {
136    let mut variables = IndexSet::new();
137    for literal in &rule.body {
138        if let Literal::Positive(atom) = literal {
139            variables.extend(atom_variables(atom));
140        }
141    }
142    variables
143}
144
145fn validate_rule_safety(
146    rule: &RuleAst,
147    positive_variables: &IndexSet<Variable>,
148) -> Result<(), CompileError> {
149    for variable in atom_variables(&rule.head) {
150        if !positive_variables.contains(&variable) {
151            return Err(CompileError::UnsafeVariable {
152                rule_id: rule.id,
153                variable: variable.0,
154            });
155        }
156    }
157
158    for literal in &rule.body {
159        if let Literal::Negative(atom) = literal {
160            for variable in atom_variables(atom) {
161                if !positive_variables.contains(&variable) {
162                    return Err(CompileError::UnsafeVariable {
163                        rule_id: rule.id,
164                        variable: variable.0,
165                    });
166                }
167            }
168        }
169    }
170
171    Ok(())
172}
173
174fn atom_variables(atom: &Atom) -> IndexSet<Variable> {
175    atom.terms
176        .iter()
177        .filter_map(|term| match term {
178            Term::Variable(variable) => Some(variable.clone()),
179            Term::Aggregate(aggregate) => Some(aggregate.variable.clone()),
180            Term::Value(_) => None,
181        })
182        .collect()
183}
184
185fn literal_atom(literal: &Literal) -> &Atom {
186    match literal {
187        Literal::Positive(atom) | Literal::Negative(atom) => atom,
188    }
189}
190
191fn validate_rule_types_and_aggregates(
192    schema: &Schema,
193    rule: &RuleAst,
194    positive_variables: &IndexSet<Variable>,
195) -> Result<(), CompileError> {
196    let variable_types = infer_rule_variable_types(schema, rule)?;
197    let signature = schema
198        .predicate(&rule.head.predicate.id)
199        .expect("validated rule head predicate is present in schema");
200    let aggregates = head_aggregates(rule);
201
202    for (_, aggregate_term) in &aggregates {
203        if rule.head.terms.iter().any(
204            |term| matches!(term, Term::Variable(variable) if variable == &aggregate_term.variable),
205        ) {
206            return Err(CompileError::AggregateVariableInGroupKey {
207                rule_id: rule.id,
208                variable: aggregate_term.variable.0.clone(),
209            });
210        }
211    }
212
213    for (position, (term, expected)) in rule.head.terms.iter().zip(&signature.fields).enumerate() {
214        match term {
215            Term::Variable(variable) => {
216                let actual = variable_types.get(variable).ok_or_else(|| {
217                    CompileError::RuleVariableTypeUnknown {
218                        rule_id: rule.id,
219                        variable: variable.0.clone(),
220                    }
221                })?;
222                if actual != expected {
223                    return Err(CompileError::RuleTermTypeMismatch {
224                        rule_id: rule.id,
225                        predicate: signature.name.clone(),
226                        position,
227                        expected: expected.clone(),
228                        actual: actual.clone(),
229                    });
230                }
231            }
232            Term::Value(value) => {
233                if !value_matches_type(value, expected) {
234                    return Err(CompileError::RuleTermTypeMismatch {
235                        rule_id: rule.id,
236                        predicate: signature.name.clone(),
237                        position,
238                        expected: expected.clone(),
239                        actual: value_type_of(value),
240                    });
241                }
242            }
243            Term::Aggregate(aggregate_term) => {
244                if !positive_variables.contains(&aggregate_term.variable) {
245                    return Err(CompileError::UnsafeVariable {
246                        rule_id: rule.id,
247                        variable: aggregate_term.variable.0.clone(),
248                    });
249                }
250                let input_type = variable_types
251                    .get(&aggregate_term.variable)
252                    .ok_or_else(|| CompileError::RuleVariableTypeUnknown {
253                        rule_id: rule.id,
254                        variable: aggregate_term.variable.0.clone(),
255                    })?;
256                let output_type =
257                    aggregate_output_type(rule.id, aggregate_term, input_type.clone())?;
258                if &output_type != expected {
259                    return Err(CompileError::AggregateOutputTypeMismatch {
260                        rule_id: rule.id,
261                        function: aggregate_term.function,
262                        expected: expected.clone(),
263                        actual: output_type,
264                    });
265                }
266            }
267        }
268    }
269
270    Ok(())
271}
272
273fn infer_rule_variable_types(
274    schema: &Schema,
275    rule: &RuleAst,
276) -> Result<IndexMap<Variable, ValueType>, CompileError> {
277    let mut variable_types = IndexMap::new();
278    validate_atom_term_types(schema, rule.id, &rule.head, true, &mut variable_types)?;
279    for literal in &rule.body {
280        validate_atom_term_types(
281            schema,
282            rule.id,
283            literal_atom(literal),
284            false,
285            &mut variable_types,
286        )?;
287    }
288    Ok(variable_types)
289}
290
291fn validate_atom_term_types(
292    schema: &Schema,
293    rule_id: RuleId,
294    atom: &Atom,
295    allow_head_aggregate: bool,
296    variable_types: &mut IndexMap<Variable, ValueType>,
297) -> Result<(), CompileError> {
298    let signature = schema
299        .predicate(&atom.predicate.id)
300        .expect("validated atom predicate is present in schema");
301
302    for (position, (term, expected)) in atom.terms.iter().zip(&signature.fields).enumerate() {
303        match term {
304            Term::Variable(variable) => {
305                if let Some(existing) = variable_types.get(variable) {
306                    if existing != expected {
307                        return Err(CompileError::RuleVariableTypeConflict {
308                            rule_id,
309                            variable: variable.0.clone(),
310                            first: existing.clone(),
311                            second: expected.clone(),
312                        });
313                    }
314                } else {
315                    variable_types.insert(variable.clone(), expected.clone());
316                }
317            }
318            Term::Value(value) => {
319                if !value_matches_type(value, expected) {
320                    return Err(CompileError::RuleTermTypeMismatch {
321                        rule_id,
322                        predicate: signature.name.clone(),
323                        position,
324                        expected: expected.clone(),
325                        actual: value_type_of(value),
326                    });
327                }
328            }
329            Term::Aggregate(_) if !allow_head_aggregate => {
330                return Err(CompileError::AggregateOutsideHead { rule_id });
331            }
332            Term::Aggregate(_) => {}
333        }
334    }
335
336    Ok(())
337}
338
339fn head_aggregates(rule: &RuleAst) -> Vec<(usize, &AggregateTerm)> {
340    rule.head
341        .terms
342        .iter()
343        .enumerate()
344        .filter_map(|(index, term)| match term {
345            Term::Aggregate(aggregate_term) => Some((index, aggregate_term)),
346            _ => None,
347        })
348        .collect()
349}
350
351fn aggregate_output_type(
352    rule_id: RuleId,
353    aggregate: &AggregateTerm,
354    input_type: ValueType,
355) -> Result<ValueType, CompileError> {
356    match aggregate.function {
357        AggregateFunction::Count => Ok(ValueType::U64),
358        AggregateFunction::Sum => match input_type {
359            ValueType::I64 | ValueType::U64 | ValueType::F64 => Ok(input_type),
360            other => Err(CompileError::UnsupportedAggregateInputType {
361                rule_id,
362                function: aggregate.function,
363                variable: aggregate.variable.0.clone(),
364                input_type: other,
365            }),
366        },
367        AggregateFunction::Min | AggregateFunction::Max => match input_type {
368            ValueType::I64
369            | ValueType::U64
370            | ValueType::F64
371            | ValueType::String
372            | ValueType::Entity => Ok(input_type),
373            other => Err(CompileError::UnsupportedAggregateInputType {
374                rule_id,
375                function: aggregate.function,
376                variable: aggregate.variable.0.clone(),
377                input_type: other,
378            }),
379        },
380    }
381}
382
383fn validate_recursive_aggregation(
384    schema: &Schema,
385    program: &RuleProgram,
386    graph: &DependencyGraph,
387    sccs: &[StronglyConnectedComponent],
388    scc_lookup: &IndexMap<PredicateId, usize>,
389) -> Result<(), CompileError> {
390    for rule in &program.rules {
391        if head_aggregates(rule).is_empty() {
392            continue;
393        }
394
395        let scc_id = *scc_lookup
396            .get(&rule.head.predicate.id)
397            .expect("aggregate head predicate should be present in scc lookup");
398        let recursive = sccs.iter().find(|scc| scc.id == scc_id).is_some_and(|scc| {
399            scc.predicates.len() > 1
400                || scc.predicates.iter().any(|predicate| {
401                    graph
402                        .edges
403                        .get(predicate)
404                        .is_some_and(|deps| deps.contains(predicate))
405                })
406        });
407        if recursive {
408            return Err(CompileError::RecursiveAggregation {
409                rule_id: rule.id,
410                predicate: predicate_label(schema, rule.head.predicate.id),
411            });
412        }
413    }
414
415    Ok(())
416}
417
418fn compute_sccs(
419    graph: &DependencyGraph,
420    predicates: &IndexSet<PredicateId>,
421) -> Vec<StronglyConnectedComponent> {
422    let mut visited = IndexSet::new();
423    let mut order = Vec::new();
424
425    for predicate in predicates {
426        dfs_forward(*predicate, graph, &mut visited, &mut order);
427    }
428
429    let reversed = reverse_graph(graph, predicates);
430    visited.clear();
431
432    let mut sccs = Vec::new();
433    let mut next_id = 0usize;
434    while let Some(predicate) = order.pop() {
435        if visited.contains(&predicate) {
436            continue;
437        }
438        let mut component = Vec::new();
439        dfs_reverse(predicate, &reversed, &mut visited, &mut component);
440        component.sort();
441        sccs.push(StronglyConnectedComponent {
442            id: next_id,
443            predicates: component,
444        });
445        next_id += 1;
446    }
447
448    sccs
449}
450
451fn dfs_forward(
452    start: PredicateId,
453    graph: &DependencyGraph,
454    visited: &mut IndexSet<PredicateId>,
455    order: &mut Vec<PredicateId>,
456) {
457    if !visited.insert(start) {
458        return;
459    }
460
461    if let Some(neighbors) = graph.edges.get(&start) {
462        for neighbor in neighbors {
463            dfs_forward(*neighbor, graph, visited, order);
464        }
465    }
466
467    order.push(start);
468}
469
470fn reverse_graph(
471    graph: &DependencyGraph,
472    predicates: &IndexSet<PredicateId>,
473) -> IndexMap<PredicateId, Vec<PredicateId>> {
474    let mut reversed: IndexMap<PredicateId, Vec<PredicateId>> = predicates
475        .iter()
476        .map(|predicate| (*predicate, Vec::new()))
477        .collect();
478
479    for (head, dependencies) in &graph.edges {
480        for dependency in dependencies {
481            reversed.entry(*dependency).or_default().push(*head);
482        }
483    }
484
485    reversed
486}
487
488fn dfs_reverse(
489    start: PredicateId,
490    graph: &IndexMap<PredicateId, Vec<PredicateId>>,
491    visited: &mut IndexSet<PredicateId>,
492    component: &mut Vec<PredicateId>,
493) {
494    if !visited.insert(start) {
495        return;
496    }
497
498    component.push(start);
499    if let Some(neighbors) = graph.get(&start) {
500        for neighbor in neighbors {
501            dfs_reverse(*neighbor, graph, visited, component);
502        }
503    }
504}
505
506fn build_scc_lookup(sccs: &[StronglyConnectedComponent]) -> IndexMap<PredicateId, usize> {
507    let mut lookup = IndexMap::new();
508    for scc in sccs {
509        for predicate in &scc.predicates {
510            lookup.insert(*predicate, scc.id);
511        }
512    }
513    lookup
514}
515
516fn build_phase_graph(
517    schema: &Schema,
518    graph: &DependencyGraph,
519    sccs: &[StronglyConnectedComponent],
520    scc_lookup: &IndexMap<PredicateId, usize>,
521) -> aether_ast::PhaseGraph {
522    let mut nodes = Vec::new();
523    let mut edges = IndexSet::new();
524
525    for scc in sccs {
526        let provides: Vec<String> = scc
527            .predicates
528            .iter()
529            .map(|predicate| predicate_label(schema, *predicate))
530            .collect();
531        let mut available = Vec::new();
532
533        for predicate in &scc.predicates {
534            if let Some(dependencies) = graph.edges.get(predicate) {
535                for dependency in dependencies {
536                    let dependency_scc = *scc_lookup
537                        .get(dependency)
538                        .expect("predicate present in scc lookup");
539                    if dependency_scc != scc.id {
540                        available.push(predicate_label(schema, *dependency));
541                        edges.insert((dependency_scc, scc.id));
542                    }
543                }
544            }
545        }
546
547        available.sort();
548        available.dedup();
549
550        let recursive = scc.predicates.len() > 1
551            || scc.predicates.iter().any(|predicate| {
552                graph
553                    .edges
554                    .get(predicate)
555                    .is_some_and(|deps| deps.contains(predicate))
556            });
557
558        nodes.push(aether_ast::PhaseNode {
559            id: format!("scc-{}", scc.id),
560            signature: aether_ast::PhaseSignature {
561                available,
562                provides: provides.clone(),
563                keep: provides,
564            },
565            recursive_scc: recursive.then_some(scc.id),
566        });
567    }
568
569    let edges = edges
570        .into_iter()
571        .map(|(from, to)| aether_ast::PhaseEdge {
572            from: format!("scc-{}", from),
573            to: format!("scc-{}", to),
574        })
575        .collect();
576
577    aether_ast::PhaseGraph { nodes, edges }
578}
579
580fn compute_predicate_strata(
581    _schema: &Schema,
582    graph: &DependencyGraph,
583    scc_lookup: &IndexMap<PredicateId, usize>,
584    negative_edges: &[(PredicateId, PredicateId)],
585) -> Result<IndexMap<PredicateId, usize>, CompileError> {
586    let mut condensed_edges: IndexMap<(usize, usize), usize> = IndexMap::new();
587    let mut scc_ids = IndexSet::new();
588    for scc_id in scc_lookup.values() {
589        scc_ids.insert(*scc_id);
590    }
591
592    for (head, dependencies) in &graph.edges {
593        let to = *scc_lookup
594            .get(head)
595            .expect("head predicate should be present in scc lookup");
596        for dependency in dependencies {
597            let from = *scc_lookup
598                .get(dependency)
599                .expect("dependency predicate should be present in scc lookup");
600            if from != to {
601                scc_ids.insert(from);
602                scc_ids.insert(to);
603                condensed_edges.entry((from, to)).or_insert(0);
604            }
605        }
606    }
607
608    for (head, dependency) in negative_edges {
609        let to = *scc_lookup
610            .get(head)
611            .expect("negative head predicate should be present in scc lookup");
612        let from = *scc_lookup
613            .get(dependency)
614            .expect("negative dependency predicate should be present in scc lookup");
615        if from != to {
616            scc_ids.insert(from);
617            scc_ids.insert(to);
618            condensed_edges
619                .entry((from, to))
620                .and_modify(|weight| *weight = (*weight).max(1))
621                .or_insert(1);
622        }
623    }
624
625    let mut outgoing: IndexMap<usize, Vec<(usize, usize)>> = scc_ids
626        .iter()
627        .copied()
628        .map(|scc_id| (scc_id, Vec::new()))
629        .collect();
630    let mut indegree: IndexMap<usize, usize> = scc_ids
631        .iter()
632        .copied()
633        .map(|scc_id| (scc_id, 0usize))
634        .collect();
635
636    for ((from, to), weight) in condensed_edges {
637        outgoing.entry(from).or_default().push((to, weight));
638        *indegree.entry(to).or_default() += 1;
639    }
640
641    let mut ready = indegree
642        .iter()
643        .filter_map(|(scc_id, degree)| (*degree == 0).then_some(*scc_id))
644        .collect::<Vec<_>>();
645    ready.sort_unstable();
646
647    let mut order = Vec::new();
648    while let Some(scc_id) = ready.first().copied() {
649        ready.remove(0);
650        order.push(scc_id);
651        if let Some(edges) = outgoing.get(&scc_id) {
652            for (to, _) in edges {
653                let degree = indegree
654                    .get_mut(to)
655                    .expect("target scc should have indegree entry");
656                *degree -= 1;
657                if *degree == 0 {
658                    ready.push(*to);
659                    ready.sort_unstable();
660                }
661            }
662        }
663    }
664
665    if order.len() != indegree.len() {
666        return Err(CompileError::UnstratifiedNegation {
667            depender: "program".into(),
668            dependency: "negative cycle".into(),
669        });
670    }
671
672    let mut scc_strata: IndexMap<usize, usize> = scc_ids
673        .iter()
674        .copied()
675        .map(|scc_id| (scc_id, 0usize))
676        .collect();
677    for scc_id in order {
678        let current = *scc_strata
679            .get(&scc_id)
680            .expect("source scc should have a stratum");
681        if let Some(edges) = outgoing.get(&scc_id) {
682            for (to, weight) in edges {
683                let target = scc_strata
684                    .get_mut(to)
685                    .expect("target scc should have a stratum");
686                *target = (*target).max(current + *weight);
687            }
688        }
689    }
690
691    Ok(scc_lookup
692        .iter()
693        .map(|(predicate, scc_id)| {
694            (
695                *predicate,
696                *scc_strata
697                    .get(scc_id)
698                    .expect("predicate scc should have a stratum"),
699            )
700        })
701        .collect())
702}
703
704fn infer_extensional_bindings(
705    schema: &Schema,
706    program: &RuleProgram,
707) -> Result<IndexMap<PredicateId, AttributeId>, CompileError> {
708    let mut bindings = IndexMap::new();
709
710    for predicate in &program.predicates {
711        if predicate.arity != 2 {
712            continue;
713        }
714
715        if let Some(attribute) = matching_attribute(schema, &predicate.name) {
716            validate_extensional_binding(schema, predicate.id, attribute)?;
717            bindings.insert(predicate.id, attribute.id);
718        }
719    }
720
721    Ok(bindings)
722}
723
724fn matching_attribute<'a>(schema: &'a Schema, predicate_name: &str) -> Option<&'a AttributeSchema> {
725    let mut candidates = vec![predicate_name.to_owned()];
726    if predicate_name.contains('_') {
727        candidates.push(predicate_name.replacen('_', ".", 1));
728        candidates.push(predicate_name.replace('_', "."));
729    }
730
731    candidates.dedup();
732
733    candidates.into_iter().find_map(|candidate| {
734        schema
735            .attributes
736            .values()
737            .find(|attribute| attribute.name == candidate)
738    })
739}
740
741fn validate_extensional_binding(
742    schema: &Schema,
743    predicate: PredicateId,
744    attribute: &AttributeSchema,
745) -> Result<(), CompileError> {
746    let signature = schema
747        .predicate(&predicate)
748        .expect("validated predicates are present in schema");
749    let expected_fields = vec![ValueType::Entity, attribute.value_type.clone()];
750
751    if signature.fields != expected_fields {
752        return Err(CompileError::IncompatibleExtensionalBinding {
753            predicate: signature.name.clone(),
754            attribute: attribute.name.clone(),
755            expected_fields,
756            actual_fields: signature.fields.clone(),
757        });
758    }
759
760    Ok(())
761}
762
763fn predicate_label(schema: &Schema, predicate: PredicateId) -> String {
764    schema
765        .predicate(&predicate)
766        .map(|signature| signature.name.clone())
767        .unwrap_or_else(|| format!("predicate-{}", predicate))
768}
769
770fn value_matches_type(value: &Value, expected: &ValueType) -> bool {
771    match (value, expected) {
772        (Value::Null, _) => true,
773        (Value::Bool(_), ValueType::Bool) => true,
774        (Value::I64(_), ValueType::I64) => true,
775        (Value::U64(_), ValueType::U64) => true,
776        (Value::F64(_), ValueType::F64) => true,
777        (Value::String(_), ValueType::String) => true,
778        (Value::Bytes(_), ValueType::Bytes) => true,
779        (Value::Entity(_), ValueType::Entity) => true,
780        (Value::List(values), ValueType::List(inner)) => {
781            values.iter().all(|value| value_matches_type(value, inner))
782        }
783        _ => false,
784    }
785}
786
787fn value_type_of(value: &Value) -> ValueType {
788    match value {
789        Value::Null => ValueType::String,
790        Value::Bool(_) => ValueType::Bool,
791        Value::I64(_) => ValueType::I64,
792        Value::U64(_) => ValueType::U64,
793        Value::F64(_) => ValueType::F64,
794        Value::String(_) => ValueType::String,
795        Value::Bytes(_) => ValueType::Bytes,
796        Value::Entity(_) => ValueType::Entity,
797        Value::List(values) => ValueType::List(Box::new(
798            values
799                .first()
800                .map(value_type_of)
801                .unwrap_or(ValueType::String),
802        )),
803    }
804}
805
806#[derive(Debug, Error)]
807pub enum CompileError {
808    #[error(transparent)]
809    Schema(#[from] SchemaError),
810    #[error("rule {rule_id} uses aggregate terms outside a rule head")]
811    AggregateOutsideHead { rule_id: RuleId },
812    #[error("rule {rule_id} cannot group by aggregate variable {variable}")]
813    AggregateVariableInGroupKey { rule_id: RuleId, variable: String },
814    #[error(
815        "rule {rule_id} uses aggregate {function} over variable {variable} with unsupported input type {input_type:?}"
816    )]
817    UnsupportedAggregateInputType {
818        rule_id: RuleId,
819        function: AggregateFunction,
820        variable: String,
821        input_type: ValueType,
822    },
823    #[error(
824        "rule {rule_id} produces aggregate {function} with type {actual:?}, but the head expects {expected:?}"
825    )]
826    AggregateOutputTypeMismatch {
827        rule_id: RuleId,
828        function: AggregateFunction,
829        expected: ValueType,
830        actual: ValueType,
831    },
832    #[error(
833        "rule {rule_id} uses variable {variable} with incompatible types {first:?} and {second:?}"
834    )]
835    RuleVariableTypeConflict {
836        rule_id: RuleId,
837        variable: String,
838        first: ValueType,
839        second: ValueType,
840    },
841    #[error("rule {rule_id} references variable {variable}, but its type could not be inferred")]
842    RuleVariableTypeUnknown { rule_id: RuleId, variable: String },
843    #[error(
844        "rule {rule_id} uses term {position} of predicate {predicate} with type {actual:?}, expected {expected:?}"
845    )]
846    RuleTermTypeMismatch {
847        rule_id: RuleId,
848        predicate: String,
849        position: usize,
850        expected: ValueType,
851        actual: ValueType,
852    },
853    #[error("rule {rule_id} uses bounded aggregation recursively through predicate {predicate}")]
854    RecursiveAggregation { rule_id: RuleId, predicate: String },
855    #[error(
856        "predicate {predicate} cannot bind to attribute {attribute}: expected {expected_fields:?}, found {actual_fields:?}"
857    )]
858    IncompatibleExtensionalBinding {
859        predicate: String,
860        attribute: String,
861        expected_fields: Vec<ValueType>,
862        actual_fields: Vec<ValueType>,
863    },
864    #[error(
865        "fact for predicate {predicate} does not match the declared types: expected {expected:?}, found {actual:?}"
866    )]
867    FactTypeMismatch {
868        predicate: String,
869        expected: Vec<ValueType>,
870        actual: Vec<ValueType>,
871    },
872    #[error("rule {rule_id} uses unsafe variable {variable}")]
873    UnsafeVariable {
874        rule_id: aether_ast::RuleId,
875        variable: String,
876    },
877    #[error("unstratified negation detected: {depender} depends negatively on {dependency}")]
878    UnstratifiedNegation {
879        depender: String,
880        dependency: String,
881    },
882}
883
#[cfg(test)]
mod tests {
    //! Behavioural tests for `DefaultRuleCompiler`: dependency-graph and phase
    //! construction, extensional attribute binding, bounded aggregation checks,
    //! rule safety, and negation stratification.

    use super::{CompileError, DefaultRuleCompiler, RuleCompiler};
    use aether_ast::{
        AggregateFunction, AggregateTerm, Atom, AttributeId, ExtensionalFact, Literal, PredicateId,
        PredicateRef, RuleAst, RuleId, RuleProgram, Term, Value, Variable,
    };
    use aether_schema::{AttributeClass, AttributeSchema, PredicateSignature, Schema, ValueType};

    /// Builds a predicate reference with the given id, name, and arity.
    fn predicate(id: u64, name: &str, arity: usize) -> PredicateRef {
        let id = PredicateId::new(id);
        PredicateRef {
            id,
            name: name.into(),
            arity,
        }
    }

    /// Builds a variable term named `name`.
    fn var(name: &str) -> Term {
        Term::Variable(Variable::new(name))
    }

    /// Builds an atom over `predicate` whose terms are the variables in `vars`, in order.
    fn atom(predicate: PredicateRef, vars: &[&str]) -> Atom {
        let terms = vars.iter().copied().map(var).collect();
        Atom { predicate, terms }
    }

    /// Builds a head term that applies `function` to `variable`.
    fn aggregate(function: AggregateFunction, variable: &str) -> Term {
        let variable = Variable::new(variable);
        Term::Aggregate(AggregateTerm { function, variable })
    }

    /// Registers `pred` in `target` with the given field types; `context` is the
    /// panic message used if the schema rejects the signature.
    fn register(
        target: &mut Schema,
        pred: &PredicateRef,
        fields: Vec<ValueType>,
        context: &str,
    ) {
        let signature = PredicateSignature {
            id: pred.id,
            name: pred.name.clone(),
            fields,
        };
        target.register_predicate(signature).expect(context);
    }

    /// Builds a schema whose predicates all have `Entity`-typed fields.
    fn schema(predicates: &[(u64, &str, usize)]) -> Schema {
        let mut built = Schema::new("v1");
        for &(id, name, arity) in predicates {
            let fields = vec![ValueType::Entity; arity];
            register(&mut built, &predicate(id, name, arity), fields, "register predicate");
        }
        built
    }

    /// Builds a schema from explicitly typed predicate signatures, registered in order.
    fn typed_schema(signatures: Vec<(&PredicateRef, Vec<ValueType>)>) -> Schema {
        let mut built = Schema::new("v1");
        for (pred, fields) in signatures {
            register(&mut built, pred, fields, "register predicate");
        }
        built
    }

    /// Registers attribute `name` on `target`, panicking if the schema rejects it.
    fn add_attribute(
        target: &mut Schema,
        id: AttributeId,
        name: &str,
        class: AttributeClass,
        value_type: ValueType,
    ) {
        let attribute = AttributeSchema {
            id,
            name: name.into(),
            class,
            value_type,
        };
        target
            .register_attribute(attribute)
            .expect("register attribute");
    }

    /// Transitive closure over `edge` yields a recursive `reach` SCC fed by a
    /// non-recursive `edge` SCC, both in stratum 0.
    #[test]
    fn safe_recursive_program_builds_expected_graph_and_phase_boundaries() {
        let edge = predicate(1, "edge", 2);
        let reach = predicate(2, "reach", 2);
        let catalog = schema(&[(1, "edge", 2), (2, "reach", 2)]);

        // reach(x, y) :- edge(x, y).
        let base = RuleAst {
            id: RuleId::new(1),
            head: atom(reach.clone(), &["x", "y"]),
            body: vec![Literal::Positive(atom(edge.clone(), &["x", "y"]))],
        };
        // reach(x, z) :- reach(x, y), edge(y, z).
        let step = RuleAst {
            id: RuleId::new(2),
            head: atom(reach.clone(), &["x", "z"]),
            body: vec![
                Literal::Positive(atom(reach.clone(), &["x", "y"])),
                Literal::Positive(atom(edge.clone(), &["y", "z"])),
            ],
        };
        let program = RuleProgram {
            predicates: vec![edge.clone(), reach.clone()],
            rules: vec![base, step],
            materialized: vec![reach.id],
            facts: Vec::new(),
        };

        let compiled = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect("compile recursive program");

        let reach_edges = compiled
            .dependency_graph
            .edges
            .get(&reach.id)
            .expect("reach edges");
        assert!(reach_edges.contains(&edge.id));
        assert!(reach_edges.contains(&reach.id));
        assert_eq!(compiled.sccs.len(), 2);

        // Finds the component that owns a given predicate.
        let scc_of = |id: PredicateId| {
            compiled
                .sccs
                .iter()
                .find(|scc| scc.predicates.contains(&id))
        };
        let reach_scc = scc_of(reach.id).expect("reach scc");
        let edge_scc = scc_of(edge.id).expect("edge scc");

        let reach_label = format!("scc-{}", reach_scc.id);
        let edge_label = format!("scc-{}", edge_scc.id);
        let nodes = &compiled.phase_graph.nodes;
        let reach_node = nodes
            .iter()
            .find(|node| node.id == reach_label)
            .expect("reach phase node");
        let edge_node = nodes
            .iter()
            .find(|node| node.id == edge_label)
            .expect("edge phase node");

        assert_eq!(reach_node.recursive_scc, Some(reach_scc.id));
        assert_eq!(edge_node.recursive_scc, None);
        assert_eq!(compiled.predicate_strata.get(&edge.id).copied(), Some(0));
        assert_eq!(compiled.predicate_strata.get(&reach.id).copied(), Some(0));

        let edge_feeds_reach = compiled
            .phase_graph
            .edges
            .iter()
            .any(|link| link.from == edge_label && link.to == reach_label);
        assert!(edge_feeds_reach);
        assert_eq!(compiled.rules, program.rules);
    }

    /// A predicate named `task_depends_on` binds to the `task.depends_on` attribute.
    #[test]
    fn extensional_predicates_bind_to_matching_attribute_names() {
        let task_depends_on = predicate(10, "task_depends_on", 2);
        let depends_transitive = predicate(11, "depends_transitive", 2);
        let mut catalog = schema(&[(10, "task_depends_on", 2), (11, "depends_transitive", 2)]);
        add_attribute(
            &mut catalog,
            AttributeId::new(21),
            "task.depends_on",
            AttributeClass::RefSet,
            ValueType::Entity,
        );

        // depends_transitive(x, y) :- task_depends_on(x, y).
        let rule = RuleAst {
            id: RuleId::new(1),
            head: atom(depends_transitive.clone(), &["x", "y"]),
            body: vec![Literal::Positive(atom(
                task_depends_on.clone(),
                &["x", "y"],
            ))],
        };
        let program = RuleProgram {
            predicates: vec![task_depends_on.clone(), depends_transitive],
            rules: vec![rule],
            materialized: vec![task_depends_on.id],
            facts: Vec::new(),
        };

        let compiled = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect("compile program");

        let bound = compiled.extensional_bindings.get(&task_depends_on.id);
        assert_eq!(bound, Some(&AttributeId::new(21)));
    }

    /// Aggregates must produce the declared head type and must not recurse
    /// through their own head predicate.
    #[test]
    fn bounded_aggregation_requires_non_recursive_rules_and_matching_output_types() {
        let edge = predicate(1, "edge", 2);
        let reach_count = predicate(2, "reach_count", 2);
        let mut aggregate_schema = Schema::new("v1");
        register(
            &mut aggregate_schema,
            &edge,
            vec![ValueType::Entity, ValueType::Entity],
            "register edge",
        );
        register(
            &mut aggregate_schema,
            &reach_count,
            vec![ValueType::Entity, ValueType::String],
            "register aggregate predicate",
        );

        // reach_count(x, count(y)) :- edge(x, y), but the count slot is declared String.
        let mistyped = RuleProgram {
            predicates: vec![edge.clone(), reach_count.clone()],
            rules: vec![RuleAst {
                id: RuleId::new(1),
                head: Atom {
                    predicate: reach_count.clone(),
                    terms: vec![var("x"), aggregate(AggregateFunction::Count, "y")],
                },
                body: vec![Literal::Positive(atom(edge.clone(), &["x", "y"]))],
            }],
            materialized: vec![reach_count.id],
            facts: Vec::new(),
        };
        let type_mismatch = DefaultRuleCompiler
            .compile(&aggregate_schema, &mistyped)
            .expect_err("aggregate output type mismatch should fail");
        assert!(matches!(
            type_mismatch,
            CompileError::AggregateOutputTypeMismatch {
                rule_id,
                function: AggregateFunction::Count,
                expected: ValueType::String,
                actual: ValueType::U64,
            } if rule_id == RuleId::new(1)
        ));

        let bad_count = predicate(3, "bad_count", 2);
        let mut recursive_schema = Schema::new("v1");
        register(
            &mut recursive_schema,
            &edge,
            vec![ValueType::Entity, ValueType::Entity],
            "register edge",
        );
        register(
            &mut recursive_schema,
            &bad_count,
            vec![ValueType::Entity, ValueType::U64],
            "register recursive aggregate predicate",
        );

        // bad_count(x, count(y)) :- bad_count(x, y) aggregates over its own output.
        let self_referential = RuleProgram {
            predicates: vec![edge, bad_count.clone()],
            rules: vec![RuleAst {
                id: RuleId::new(2),
                head: Atom {
                    predicate: bad_count.clone(),
                    terms: vec![var("x"), aggregate(AggregateFunction::Count, "y")],
                },
                body: vec![Literal::Positive(atom(bad_count.clone(), &["x", "y"]))],
            }],
            materialized: vec![bad_count.id],
            facts: Vec::new(),
        };
        let recursive = DefaultRuleCompiler
            .compile(&recursive_schema, &self_referential)
            .expect_err("recursive aggregate should fail");
        assert!(matches!(
            recursive,
            CompileError::RecursiveAggregation { rule_id, predicate }
                if rule_id == RuleId::new(2) && predicate == "bad_count"
        ));
    }

    /// A single head may carry several aggregate terms.
    #[test]
    fn bounded_aggregation_allows_multiple_head_aggregates() {
        let project_task = predicate(1, "project_task", 2);
        let task_hours = predicate(2, "task_hours", 2);
        let project_stats = predicate(3, "project_stats", 3);
        let catalog = typed_schema(vec![
            (&project_task, vec![ValueType::Entity, ValueType::Entity]),
            (&task_hours, vec![ValueType::Entity, ValueType::U64]),
            (
                &project_stats,
                vec![ValueType::Entity, ValueType::U64, ValueType::U64],
            ),
        ]);

        // project_stats(project, count(task), sum(hours)) :-
        //     project_task(project, task), task_hours(task, hours).
        let rule = RuleAst {
            id: RuleId::new(10),
            head: Atom {
                predicate: project_stats.clone(),
                terms: vec![
                    var("project"),
                    aggregate(AggregateFunction::Count, "task"),
                    aggregate(AggregateFunction::Sum, "hours"),
                ],
            },
            body: vec![
                Literal::Positive(atom(project_task.clone(), &["project", "task"])),
                Literal::Positive(atom(task_hours.clone(), &["task", "hours"])),
            ],
        };
        let program = RuleProgram {
            predicates: vec![project_task, task_hours, project_stats.clone()],
            rules: vec![rule],
            materialized: vec![project_stats.id],
            facts: Vec::new(),
        };

        DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect("compile multiple head aggregates");
    }

    /// A name-matched attribute whose value types disagree with the predicate
    /// signature is rejected instead of silently bound.
    #[test]
    fn extensional_binding_rejects_type_mismatches() {
        let task_depends_on = predicate(10, "task_depends_on", 2);
        let mut catalog = typed_schema(vec![(
            &task_depends_on,
            vec![ValueType::String, ValueType::Entity],
        )]);
        add_attribute(
            &mut catalog,
            AttributeId::new(21),
            "task.depends_on",
            AttributeClass::RefSet,
            ValueType::Entity,
        );

        let program = RuleProgram {
            predicates: vec![task_depends_on],
            rules: Vec::new(),
            materialized: Vec::new(),
            facts: Vec::new(),
        };
        let error = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect_err("type-mismatched binding should fail");

        assert!(matches!(
            error,
            CompileError::IncompatibleExtensionalBinding {
                predicate,
                attribute,
                expected_fields,
                actual_fields,
            } if predicate == "task_depends_on"
                && attribute == "task.depends_on"
                && expected_fields == vec![ValueType::Entity, ValueType::Entity]
                && actual_fields == vec![ValueType::String, ValueType::Entity]
        ));
    }

    /// A head variable that no positive body literal binds is reported as unsafe.
    #[test]
    fn unsafe_variables_are_rejected() {
        let ready = predicate(1, "ready", 1);
        let edge = predicate(2, "edge", 2);
        let catalog = schema(&[(1, "ready", 1), (2, "edge", 2)]);

        // ready(x) :- edge(y, z).   -- `x` is never bound by the body.
        let rule = RuleAst {
            id: RuleId::new(7),
            head: atom(ready.clone(), &["x"]),
            body: vec![Literal::Positive(atom(edge.clone(), &["y", "z"]))],
        };
        let program = RuleProgram {
            predicates: vec![ready, edge],
            rules: vec![rule],
            materialized: Vec::new(),
            facts: Vec::new(),
        };

        let error = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect_err("unsafe rule should fail");
        assert!(matches!(
            error,
            CompileError::UnsafeVariable { variable, .. } if variable == "x"
        ));
    }

    /// Negation between predicates of the same recursive component is rejected.
    #[test]
    fn unstratified_negation_in_recursive_component_is_rejected() {
        let p = predicate(1, "p", 1);
        let q = predicate(2, "q", 1);
        let catalog = schema(&[(1, "p", 1), (2, "q", 1)]);

        // p(x) :- q(x).
        // q(x) :- p(x), not p(x).   -- q negates p while p and q are mutually recursive.
        let rules = vec![
            RuleAst {
                id: RuleId::new(1),
                head: atom(p.clone(), &["x"]),
                body: vec![Literal::Positive(atom(q.clone(), &["x"]))],
            },
            RuleAst {
                id: RuleId::new(2),
                head: atom(q.clone(), &["x"]),
                body: vec![
                    Literal::Positive(atom(p.clone(), &["x"])),
                    Literal::Negative(atom(p.clone(), &["x"])),
                ],
            },
        ];
        let program = RuleProgram {
            predicates: vec![p, q],
            rules,
            materialized: Vec::new(),
            facts: Vec::new(),
        };

        let error = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect_err("unstratified negation should fail");
        assert!(matches!(
            error,
            CompileError::UnstratifiedNegation { depender, dependency }
                if depender == "q" && dependency == "p"
        ));
    }

    /// A predicate that negates another is placed one stratum above it.
    #[test]
    fn stratified_negation_assigns_higher_strata() {
        let task = predicate(1, "task", 1);
        let task_status = predicate(2, "task_status", 2);
        let task_complete = predicate(3, "task_complete", 1);
        let task_ready = predicate(4, "task_ready", 1);
        let mut catalog = typed_schema(vec![
            (&task, vec![ValueType::Entity]),
            (&task_status, vec![ValueType::Entity, ValueType::String]),
            (&task_complete, vec![ValueType::Entity]),
            (&task_ready, vec![ValueType::Entity]),
        ]);
        add_attribute(
            &mut catalog,
            AttributeId::new(20),
            "task.status",
            AttributeClass::ScalarLww,
            ValueType::String,
        );

        // task_complete(x) :- task_status(x, "done").
        let complete_rule = RuleAst {
            id: RuleId::new(1),
            head: atom(task_complete.clone(), &["x"]),
            body: vec![Literal::Positive(Atom {
                predicate: task_status.clone(),
                terms: vec![var("x"), Term::Value(Value::String("done".into()))],
            })],
        };
        // task_ready(x) :- task(x), not task_complete(x).
        let ready_rule = RuleAst {
            id: RuleId::new(2),
            head: atom(task_ready.clone(), &["x"]),
            body: vec![
                Literal::Positive(atom(task.clone(), &["x"])),
                Literal::Negative(atom(task_complete.clone(), &["x"])),
            ],
        };
        let seed = ExtensionalFact {
            predicate: task.clone(),
            values: vec![Value::Entity(aether_ast::EntityId::new(1))],
            policy: None,
            provenance: None,
        };
        let program = RuleProgram {
            predicates: vec![
                task,
                task_status,
                task_complete.clone(),
                task_ready.clone(),
            ],
            rules: vec![complete_rule, ready_rule],
            materialized: vec![task_ready.id],
            facts: vec![seed],
        };

        let compiled = DefaultRuleCompiler
            .compile(&catalog, &program)
            .expect("compile stratified program");

        let stratum_of = |id: PredicateId| compiled.predicate_strata.get(&id).copied();
        assert_eq!(stratum_of(task_complete.id), Some(0));
        assert_eq!(stratum_of(task_ready.id), Some(1));
        assert_eq!(compiled.facts.len(), 1);
    }
}