aether_rules/
parser.rs

1use aether_ast::{
2    AggregateFunction, AggregateTerm, Atom, AttributeId, ExplainSpec, ExplainTarget,
3    ExtensionalFact, Literal, NamedExplainSpec, NamedQuerySpec, PolicyEnvelope, PredicateId,
4    PredicateRef, QueryAst, QuerySpec, RuleAst, RuleId, RuleProgram, TemporalView, Term, Value,
5    Variable,
6};
7use aether_schema::{
8    AttributeClass, AttributeSchema, PredicateSignature, Schema, SchemaError, ValueType,
9};
10use indexmap::{IndexMap, IndexSet};
11use thiserror::Error;
12
/// Entry point for parsing a complete rules-DSL document.
pub trait DslParser {
    /// Parses `input` into a [`DslDocument`], or returns the first
    /// [`ParseError`] encountered.
    fn parse_document(&self, input: &str) -> Result<DslDocument, ParseError>;
}
16
/// The fully parsed result of one DSL document: schema, rule program,
/// and any query/explain specifications.
#[derive(Clone, Debug, PartialEq)]
pub struct DslDocument {
    /// Attribute and predicate declarations from the `schema`/`predicates` sections.
    pub schema: Schema,
    /// Rules, facts, predicates, and materialization directives.
    pub program: RuleProgram,
    /// Primary query: the unnamed `query` section when present,
    /// otherwise the first `query` section in document order.
    pub query: Option<QuerySpec>,
    /// Every `query` section (named and unnamed), in document order.
    pub queries: Vec<NamedQuerySpec>,
    /// Every `explain` section, in document order.
    pub explains: Vec<NamedExplainSpec>,
}
25
/// Default [`DslParser`] implementation backed by the hand-written
/// section/line parser in this module.
#[derive(Default)]
pub struct DefaultDslParser;
28
impl DslParser for DefaultDslParser {
    // Delegates to the free function `parse_document` below.
    fn parse_document(&self, input: &str) -> Result<DslDocument, ParseError> {
        parse_document(input)
    }
}
34
/// Parses a full DSL document.
///
/// Required sections: `schema`, `predicates`, `rules`. Optional: `facts`,
/// `materialize`. Any number of `query` and `explain` sections may appear.
/// Parsing is staged: sections are split first, then schema and predicate
/// declarations are parsed so every later section can resolve predicate
/// names against `predicate_refs`.
fn parse_document(input: &str) -> Result<DslDocument, ParseError> {
    let sections = collect_sections(input)?;
    let schema_section = single_section(&sections, "schema")?;
    let predicates_section = single_section(&sections, "predicates")?;
    let rules_section = single_section(&sections, "rules")?;
    let facts_section = optional_single_section(&sections, "facts")?;
    let materialize_section = optional_single_section(&sections, "materialize")?;
    // Repeatable sections default to an empty slice when absent.
    let query_sections = sections.get("query").map(Vec::as_slice).unwrap_or(&[]);
    let explain_sections = sections.get("explain").map(Vec::as_slice).unwrap_or(&[]);

    let mut schema = parse_schema_section(schema_section)?;
    // Registers predicate signatures into `schema` and returns the
    // name -> PredicateRef table used by every section below.
    let predicate_refs = parse_predicates_section(predicates_section, &mut schema)?;
    let facts = parse_facts_section(facts_section, &predicate_refs)?;
    let rules = parse_rules_section(rules_section, &predicate_refs)?;
    let materialized = parse_materialize_section(materialize_section, &predicate_refs)?;
    let (query, queries) = parse_query_sections(query_sections, &predicate_refs)?;
    let explains = parse_explain_sections(explain_sections, &predicate_refs)?;

    Ok(DslDocument {
        program: RuleProgram {
            predicates: predicate_refs.values().cloned().collect(),
            rules,
            materialized,
            facts,
        },
        schema,
        query,
        queries,
        explains,
    })
}
66
/// One `name [argument] { ... }` block lifted out of the raw input.
#[derive(Clone, Debug)]
struct Section {
    /// Canonical section name (aliases such as `queries` already normalized).
    name: String,
    /// Optional header argument (e.g. a schema version or a query name).
    argument: Option<String>,
    /// 1-based line number of the section header, for error reporting.
    line: usize,
    /// Statements inside the block paired with their line numbers; comments
    /// and trailing `;`/`.` are stripped, so an entry may be empty.
    entries: Vec<(usize, String)>,
}
74
/// Splits the raw input into top-level `name [arg] { ... }` sections,
/// keyed by canonical name, preserving document order per name.
///
/// Line grammar (after comment stripping and trimming; blank lines skipped):
///   * inside a section, a lone `}` closes it; any other line becomes an
///     entry with its trailing `;`/`.` removed,
///   * outside a section, only a header ending in `{` is accepted.
///
/// Sections cannot nest. A duplicate non-repeatable section is reported when
/// its second copy closes; `query`/`explain` may repeat freely.
fn collect_sections(input: &str) -> Result<IndexMap<String, Vec<Section>>, ParseError> {
    let mut sections: IndexMap<String, Vec<Section>> = IndexMap::new();
    // The section currently being filled, if any.
    let mut current: Option<Section> = None;

    for (index, raw_line) in input.lines().enumerate() {
        let line_number = index + 1;
        let line = strip_comments(raw_line).trim().to_owned();
        if line.is_empty() {
            continue;
        }

        if let Some(section) = current.as_mut() {
            if line == "}" {
                let section = current.take().expect("section present");
                // Checked on close so the error carries this block's header line.
                if !is_repeatable_section(&section.name) && sections.contains_key(&section.name) {
                    return Err(ParseError::DuplicateSection {
                        line: section.line,
                        section: section.name,
                    });
                }
                sections
                    .entry(section.name.clone())
                    .or_default()
                    .push(section);
                continue;
            }

            section
                .entries
                .push((line_number, trim_statement(&line).to_owned()));
            continue;
        }

        // Top level: the only legal line is a section header `name [arg] {`.
        let Some(header) = line.strip_suffix('{') else {
            return Err(ParseError::UnexpectedTopLevel {
                line: line_number,
                content: line,
            });
        };
        let header = header.trim();
        let (name, argument) = parse_section_header(line_number, header)?;
        current = Some(Section {
            name,
            argument,
            line: line_number,
            entries: Vec::new(),
        });
    }

    // Reaching EOF with an open section means a missing `}`.
    if let Some(section) = current {
        return Err(ParseError::UnterminatedSection {
            line: section.line,
            section: section.name,
        });
    }

    Ok(sections)
}
133
134fn parse_section_header(line: usize, header: &str) -> Result<(String, Option<String>), ParseError> {
135    let mut parts = header.split_whitespace();
136    let name = parts
137        .next()
138        .ok_or_else(|| ParseError::InvalidSectionHeader {
139            line,
140            header: header.into(),
141        })?;
142    let name = match name {
143        "schema" | "predicates" | "rules" | "materialize" | "facts" | "query" | "explain" => {
144            name.to_owned()
145        }
146        "queries" => "query".to_owned(),
147        "explains" => "explain".to_owned(),
148        "materialized" => "materialize".to_owned(),
149        other => {
150            return Err(ParseError::UnknownSection {
151                line,
152                section: other.into(),
153            })
154        }
155    };
156    let argument = parts.next().map(ToOwned::to_owned);
157    if parts.next().is_some() {
158        return Err(ParseError::InvalidSectionHeader {
159            line,
160            header: header.into(),
161        });
162    }
163
164    Ok((name, argument))
165}
166
/// Sections that may legally appear more than once in a document.
fn is_repeatable_section(name: &str) -> bool {
    name == "query" || name == "explain"
}
170
171fn single_section<'a>(
172    sections: &'a IndexMap<String, Vec<Section>>,
173    name: &'static str,
174) -> Result<&'a Section, ParseError> {
175    let entries = sections.get(name).ok_or(ParseError::MissingSection(name))?;
176    entries.first().ok_or(ParseError::MissingSection(name))
177}
178
179fn optional_single_section<'a>(
180    sections: &'a IndexMap<String, Vec<Section>>,
181    name: &'static str,
182) -> Result<Option<&'a Section>, ParseError> {
183    Ok(sections.get(name).and_then(|entries| entries.first()))
184}
185
186fn parse_schema_section(section: &Section) -> Result<Schema, ParseError> {
187    let mut schema = Schema::new(section.argument.as_deref().unwrap_or("v1"));
188    let mut next_attribute_id = 1u64;
189
190    for (line, entry) in &section.entries {
191        if entry.is_empty() {
192            continue;
193        }
194
195        let Some(entry) = entry.strip_prefix("attr ") else {
196            return Err(ParseError::InvalidSchemaEntry {
197                line: *line,
198                entry: entry.clone(),
199            });
200        };
201        let Some((name, spec)) = entry.split_once(':') else {
202            return Err(ParseError::InvalidSchemaEntry {
203                line: *line,
204                entry: entry.into(),
205            });
206        };
207        let name = name.trim();
208        let spec = spec.trim();
209        let (class, value_type) = parse_attribute_spec(*line, spec)?;
210        schema
211            .register_attribute(AttributeSchema {
212                id: AttributeId::new(next_attribute_id),
213                name: name.into(),
214                class,
215                value_type,
216            })
217            .map_err(|source| ParseError::Schema {
218                line: *line,
219                source,
220            })?;
221        next_attribute_id += 1;
222    }
223
224    Ok(schema)
225}
226
227fn parse_predicates_section(
228    section: &Section,
229    schema: &mut Schema,
230) -> Result<IndexMap<String, PredicateRef>, ParseError> {
231    let mut predicate_refs = IndexMap::new();
232    let mut next_predicate_id = 1u64;
233
234    for (line, entry) in &section.entries {
235        if entry.is_empty() {
236            continue;
237        }
238
239        let (name, args) = parse_call(*line, entry)?;
240        let fields = args
241            .iter()
242            .map(|token| parse_value_type(*line, token))
243            .collect::<Result<Vec<_>, _>>()?;
244        let id = PredicateId::new(next_predicate_id);
245        schema
246            .register_predicate(PredicateSignature {
247                id,
248                name: name.into(),
249                fields: fields.clone(),
250            })
251            .map_err(|source| ParseError::Schema {
252                line: *line,
253                source,
254            })?;
255        predicate_refs.insert(
256            name.into(),
257            PredicateRef {
258                id,
259                name: name.into(),
260                arity: fields.len(),
261            },
262        );
263        next_predicate_id += 1;
264    }
265
266    Ok(predicate_refs)
267}
268
269fn parse_rules_section(
270    section: &Section,
271    predicate_refs: &IndexMap<String, PredicateRef>,
272) -> Result<Vec<RuleAst>, ParseError> {
273    let mut rules = Vec::new();
274
275    for (index, (line, entry)) in section.entries.iter().enumerate() {
276        if entry.is_empty() {
277            continue;
278        }
279
280        let Some((head, body)) = entry.split_once("<-") else {
281            return Err(ParseError::InvalidRule {
282                line: *line,
283                rule: entry.clone(),
284            });
285        };
286        let head = parse_atom(*line, head.trim(), predicate_refs, true)?;
287        let body = split_top_level(body.trim(), ',', *line)?
288            .into_iter()
289            .filter(|literal| !literal.is_empty())
290            .map(|literal| parse_literal(*line, &literal, predicate_refs))
291            .collect::<Result<Vec<_>, _>>()?;
292
293        rules.push(RuleAst {
294            id: RuleId::new(index as u64 + 1),
295            head,
296            body,
297        });
298    }
299
300    Ok(rules)
301}
302
303fn parse_facts_section(
304    section: Option<&Section>,
305    predicate_refs: &IndexMap<String, PredicateRef>,
306) -> Result<Vec<ExtensionalFact>, ParseError> {
307    let Some(section) = section else {
308        return Ok(Vec::new());
309    };
310
311    let mut facts = Vec::new();
312    for (line, entry) in &section.entries {
313        if entry.is_empty() {
314            continue;
315        }
316
317        let (call, annotation_text) = split_call_and_suffix(*line, entry)?;
318        let (name, args) = parse_call(*line, call)?;
319        let predicate =
320            predicate_refs
321                .get(name)
322                .ok_or_else(|| ParseError::UnknownPredicateName {
323                    line: *line,
324                    name: name.into(),
325                })?;
326        if predicate.arity != args.len() {
327            return Err(ParseError::PredicateArityMismatch {
328                line: *line,
329                predicate: predicate.name.clone(),
330                expected: predicate.arity,
331                actual: args.len(),
332            });
333        }
334
335        let values = args
336            .iter()
337            .map(|token| parse_fact_value(*line, token))
338            .collect::<Result<Vec<_>, _>>()?;
339        let policy = parse_policy_annotations(*line, annotation_text)?;
340
341        facts.push(ExtensionalFact {
342            predicate: predicate.clone(),
343            values,
344            policy,
345            provenance: None,
346        });
347    }
348
349    Ok(facts)
350}
351
352fn parse_materialize_section(
353    section: Option<&Section>,
354    predicate_refs: &IndexMap<String, PredicateRef>,
355) -> Result<Vec<PredicateId>, ParseError> {
356    let Some(section) = section else {
357        return Ok(Vec::new());
358    };
359
360    let mut seen = IndexSet::new();
361    let mut materialized = Vec::new();
362
363    for (line, entry) in &section.entries {
364        for name in split_top_level(entry, ',', *line)? {
365            if name.is_empty() {
366                continue;
367            }
368            let predicate = predicate_refs.get(name.as_str()).ok_or_else(|| {
369                ParseError::UnknownPredicateName {
370                    line: *line,
371                    name: name.clone(),
372                }
373            })?;
374            if !seen.insert(predicate.id) {
375                return Err(ParseError::DuplicateMaterializedPredicate { line: *line, name });
376            }
377            materialized.push(predicate.id);
378        }
379    }
380
381    Ok(materialized)
382}
383
/// Parses every `query` section in document order.
///
/// Section names must be unique; two unnamed sections also collide because
/// both have name `None`. Each body is parsed before the duplicate check,
/// so a malformed body is reported ahead of a name clash.
///
/// The primary query is the unnamed section when one exists (it overwrites
/// the slot even if a named section came first), otherwise the first
/// section in document order.
fn parse_query_sections(
    sections: &[Section],
    predicate_refs: &IndexMap<String, PredicateRef>,
) -> Result<(Option<QuerySpec>, Vec<NamedQuerySpec>), ParseError> {
    let mut named_queries = Vec::new();
    let mut seen_names = IndexSet::new();
    let mut primary_query = None;

    for section in sections {
        let query = parse_single_query_section(section, predicate_refs)?;
        if !seen_names.insert(section.argument.clone()) {
            return Err(ParseError::DuplicateNamedSection {
                line: section.line,
                section: "query".into(),
                name: section.argument.clone(),
            });
        }
        // First section wins unless a later unnamed section claims the slot.
        if primary_query.is_none() || section.argument.is_none() {
            primary_query = Some(query.clone());
        }
        named_queries.push(NamedQuerySpec {
            name: section.argument.clone(),
            spec: query,
        });
    }

    Ok((primary_query, named_queries))
}
412
413fn parse_single_query_section(
414    section: &Section,
415    predicate_refs: &IndexMap<String, PredicateRef>,
416) -> Result<QuerySpec, ParseError> {
417    let mut view = TemporalView::Current;
418    let mut goals = Vec::new();
419    let mut keep = Vec::new();
420
421    for (line, entry) in &section.entries {
422        if let Some(rest) = entry.strip_prefix("as_of ") {
423            let Some(element) = rest.trim().strip_prefix('e') else {
424                return Err(ParseError::InvalidQueryEntry {
425                    line: *line,
426                    entry: entry.clone(),
427                });
428            };
429            let element = element
430                .parse::<u64>()
431                .map_err(|_| ParseError::InvalidQueryEntry {
432                    line: *line,
433                    entry: entry.clone(),
434                })?;
435            view = TemporalView::AsOf(aether_ast::ElementId::new(element));
436            continue;
437        }
438        if entry == "current" {
439            view = TemporalView::Current;
440            continue;
441        }
442        if let Some(rest) = entry
443            .strip_prefix("goal ")
444            .or_else(|| entry.strip_prefix("find "))
445        {
446            goals.push(parse_atom(*line, rest.trim(), predicate_refs, false)?);
447            continue;
448        }
449        if let Some(rest) = entry.strip_prefix("keep ") {
450            keep.extend(
451                split_top_level(rest.trim(), ',', *line)?
452                    .into_iter()
453                    .filter(|name| !name.is_empty())
454                    .map(Variable::new),
455            );
456            continue;
457        }
458
459        return Err(ParseError::InvalidQueryEntry {
460            line: *line,
461            entry: entry.clone(),
462        });
463    }
464
465    Ok(QuerySpec {
466        view,
467        query: QueryAst { goals, keep },
468    })
469}
470
471fn parse_explain_sections(
472    sections: &[Section],
473    predicate_refs: &IndexMap<String, PredicateRef>,
474) -> Result<Vec<NamedExplainSpec>, ParseError> {
475    let mut explains = Vec::new();
476    let mut seen_names = IndexSet::new();
477
478    for section in sections {
479        if !seen_names.insert(section.argument.clone()) {
480            return Err(ParseError::DuplicateNamedSection {
481                line: section.line,
482                section: "explain".into(),
483                name: section.argument.clone(),
484            });
485        }
486        explains.push(NamedExplainSpec {
487            name: section.argument.clone(),
488            spec: parse_single_explain_section(section, predicate_refs)?,
489        });
490    }
491
492    Ok(explains)
493}
494
/// Parses the body of one `explain` section.
///
/// Accepts an optional temporal view (`current` / `as_of eN`) plus exactly
/// one target: `plan`, or `tuple pred(...)` where every argument is a ground
/// value. A second target of any kind is rejected, as is a tuple containing
/// variables; a section with no target at all is an error.
fn parse_single_explain_section(
    section: &Section,
    predicate_refs: &IndexMap<String, PredicateRef>,
) -> Result<ExplainSpec, ParseError> {
    let mut view = TemporalView::Current;
    let mut target = None;

    for (line, entry) in &section.entries {
        if let Some(rest) = entry.strip_prefix("as_of ") {
            // Element ids are written `e<N>`.
            let Some(element) = rest.trim().strip_prefix('e') else {
                return Err(ParseError::InvalidExplainEntry {
                    line: *line,
                    entry: entry.clone(),
                });
            };
            let element = element
                .parse::<u64>()
                .map_err(|_| ParseError::InvalidExplainEntry {
                    line: *line,
                    entry: entry.clone(),
                })?;
            view = TemporalView::AsOf(aether_ast::ElementId::new(element));
            continue;
        }
        if entry == "current" {
            view = TemporalView::Current;
            continue;
        }
        if entry == "plan" {
            // `replace` returns the previous target: Some means a duplicate.
            if target.replace(ExplainTarget::Plan).is_some() {
                return Err(ParseError::InvalidExplainEntry {
                    line: *line,
                    entry: entry.clone(),
                });
            }
            continue;
        }
        if let Some(rest) = entry.strip_prefix("tuple ") {
            let atom = parse_atom(*line, rest.trim(), predicate_refs, false)?;
            // Explained tuples must be fully ground (values only).
            if !atom.terms.iter().all(|term| matches!(term, Term::Value(_))) {
                return Err(ParseError::NonGroundExplainTuple {
                    line: *line,
                    entry: entry.clone(),
                });
            }
            if target.replace(ExplainTarget::Tuple(atom)).is_some() {
                return Err(ParseError::InvalidExplainEntry {
                    line: *line,
                    entry: entry.clone(),
                });
            }
            continue;
        }

        return Err(ParseError::InvalidExplainEntry {
            line: *line,
            entry: entry.clone(),
        });
    }

    let target = target.ok_or(ParseError::MissingExplainTarget {
        line: section.line,
        name: section.argument.clone(),
    })?;
    Ok(ExplainSpec { view, target })
}
561
562fn parse_attribute_spec(
563    line: usize,
564    spec: &str,
565) -> Result<(AttributeClass, ValueType), ParseError> {
566    let Some(open) = spec.find('<') else {
567        return Err(ParseError::InvalidSchemaEntry {
568            line,
569            entry: spec.into(),
570        });
571    };
572    let Some(close) = spec.rfind('>') else {
573        return Err(ParseError::InvalidSchemaEntry {
574            line,
575            entry: spec.into(),
576        });
577    };
578    if close + 1 != spec.len() {
579        return Err(ParseError::InvalidSchemaEntry {
580            line,
581            entry: spec.into(),
582        });
583    }
584
585    let class = match spec[..open].trim() {
586        "ScalarLww" | "ScalarLWW" => AttributeClass::ScalarLww,
587        "SetAddWins" => AttributeClass::SetAddWins,
588        "SequenceRga" | "SequenceRGA" => AttributeClass::SequenceRga,
589        "RefScalar" => AttributeClass::RefScalar,
590        "RefSet" => AttributeClass::RefSet,
591        _ => {
592            return Err(ParseError::InvalidSchemaEntry {
593                line,
594                entry: spec.into(),
595            })
596        }
597    };
598
599    let value_type = parse_value_type(line, &spec[open + 1..close])?;
600    Ok((class, value_type))
601}
602
603fn parse_literal(
604    line: usize,
605    literal: &str,
606    predicate_refs: &IndexMap<String, PredicateRef>,
607) -> Result<Literal, ParseError> {
608    if let Some(atom) = literal.strip_prefix("not ") {
609        return Ok(Literal::Negative(parse_atom(
610            line,
611            atom.trim(),
612            predicate_refs,
613            false,
614        )?));
615    }
616    if let Some(atom) = literal.strip_prefix('!') {
617        return Ok(Literal::Negative(parse_atom(
618            line,
619            atom.trim(),
620            predicate_refs,
621            false,
622        )?));
623    }
624
625    Ok(Literal::Positive(parse_atom(
626        line,
627        literal.trim(),
628        predicate_refs,
629        false,
630    )?))
631}
632
633fn parse_atom(
634    line: usize,
635    atom: &str,
636    predicate_refs: &IndexMap<String, PredicateRef>,
637    allow_aggregates: bool,
638) -> Result<Atom, ParseError> {
639    let (name, args) = parse_call(line, atom)?;
640    let predicate = predicate_refs
641        .get(name)
642        .ok_or_else(|| ParseError::UnknownPredicateName {
643            line,
644            name: name.into(),
645        })?;
646    if predicate.arity != args.len() {
647        return Err(ParseError::PredicateArityMismatch {
648            line,
649            predicate: predicate.name.clone(),
650            expected: predicate.arity,
651            actual: args.len(),
652        });
653    }
654
655    let terms = args
656        .iter()
657        .map(|token| parse_term(line, token, allow_aggregates))
658        .collect::<Result<Vec<_>, _>>()?;
659
660    Ok(Atom {
661        predicate: predicate.clone(),
662        terms,
663    })
664}
665
666fn parse_call(line: usize, text: &str) -> Result<(&str, Vec<String>), ParseError> {
667    let Some(open) = text.find('(') else {
668        return Err(ParseError::InvalidCall {
669            line,
670            text: text.into(),
671        });
672    };
673    let Some(close) = text.rfind(')') else {
674        return Err(ParseError::InvalidCall {
675            line,
676            text: text.into(),
677        });
678    };
679    if close + 1 != text.len() {
680        return Err(ParseError::InvalidCall {
681            line,
682            text: text.into(),
683        });
684    }
685
686    let name = text[..open].trim();
687    if name.is_empty() {
688        return Err(ParseError::InvalidCall {
689            line,
690            text: text.into(),
691        });
692    }
693
694    let inner = text[open + 1..close].trim();
695    let args = if inner.is_empty() {
696        Vec::new()
697    } else {
698        split_top_level(inner, ',', line)?
699    };
700
701    Ok((name, args))
702}
703
/// Parses a single term token into a value, aggregate, or variable.
///
/// Resolution order matters:
///   1. `"..."` string literal,
///   2. `entity(N)` entity id,
///   3. any other parenthesized token — an aggregate when allowed, else error,
///   4. `true` / `false` / `null` keywords,
///   5. numbers: a leading `-` yields I64, otherwise unsigned integers
///      become U64, and only dotted tokens become F64,
///   6. everything else is a variable name.
fn parse_term(line: usize, token: &str, allow_aggregates: bool) -> Result<Term, ParseError> {
    let token = token.trim();
    if token.is_empty() {
        return Err(ParseError::InvalidTerm {
            line,
            text: token.into(),
        });
    }

    if token.starts_with('"') {
        return Ok(Term::Value(Value::String(parse_string_literal(
            line, token,
        )?)));
    }

    if let Some(inner) = token
        .strip_prefix("entity(")
        .and_then(|rest| rest.strip_suffix(')'))
    {
        let value = inner
            .trim()
            .parse::<u64>()
            .map_err(|_| ParseError::InvalidTerm {
                line,
                text: token.into(),
            })?;
        return Ok(Term::Value(Value::Entity(aether_ast::EntityId::new(value))));
    }

    // Parentheses anywhere else can only mean an aggregate call, and only
    // where aggregates are allowed.
    if token.contains('(') || token.contains(')') {
        if allow_aggregates {
            if let Some(aggregate) = parse_aggregate_term(line, token)? {
                return Ok(Term::Aggregate(aggregate));
            }
        }
        return Err(ParseError::InvalidTerm {
            line,
            text: token.into(),
        });
    }

    match token {
        "true" => return Ok(Term::Value(Value::Bool(true))),
        "false" => return Ok(Term::Value(Value::Bool(false))),
        "null" => return Ok(Term::Value(Value::Null)),
        _ => {}
    }

    // Signed integers are produced only for an explicit leading `-`;
    // non-negative integers prefer U64 below.
    if let Ok(value) = token.parse::<i64>() {
        if token.starts_with('-') {
            return Ok(Term::Value(Value::I64(value)));
        }
    }
    if let Ok(value) = token.parse::<u64>() {
        return Ok(Term::Value(Value::U64(value)));
    }
    // NOTE(review): floats require an explicit dot — presumably so tokens
    // like `inf`, `nan`, or `1e5` stay variables instead of parsing as f64;
    // confirm scientific notation is intentionally unsupported.
    if token.contains('.') {
        if let Ok(value) = token.parse::<f64>() {
            return Ok(Term::Value(Value::F64(value)));
        }
    }

    // Anything left over is treated as a variable name.
    Ok(Term::Variable(Variable::new(token)))
}
768
769fn parse_fact_value(line: usize, token: &str) -> Result<Value, ParseError> {
770    let token = token.trim();
771    match parse_term(line, token, false)? {
772        Term::Value(value) => Ok(value),
773        Term::Variable(_) | Term::Aggregate(_) => Err(ParseError::InvalidFactValue {
774            line,
775            text: token.into(),
776        }),
777    }
778}
779
780fn parse_aggregate_term(line: usize, token: &str) -> Result<Option<AggregateTerm>, ParseError> {
781    let Ok((name, args)) = parse_call(line, token) else {
782        return Ok(None);
783    };
784
785    let function = match name {
786        "count" => AggregateFunction::Count,
787        "sum" => AggregateFunction::Sum,
788        "min" => AggregateFunction::Min,
789        "max" => AggregateFunction::Max,
790        _ => return Ok(None),
791    };
792
793    if args.len() != 1 {
794        return Err(ParseError::InvalidTerm {
795            line,
796            text: token.into(),
797        });
798    }
799
800    let variable = args[0].trim();
801    if variable.is_empty()
802        || variable.starts_with('"')
803        || variable.contains('(')
804        || variable.contains(')')
805    {
806        return Err(ParseError::InvalidTerm {
807            line,
808            text: token.into(),
809        });
810    }
811
812    Ok(Some(AggregateTerm {
813        function,
814        variable: Variable::new(variable),
815    }))
816}
817
818fn parse_string_literal(line: usize, token: &str) -> Result<String, ParseError> {
819    if token.len() < 2 || !token.ends_with('"') {
820        return Err(ParseError::InvalidTerm {
821            line,
822            text: token.into(),
823        });
824    }
825
826    let mut result = String::new();
827    let mut chars = token[1..token.len() - 1].chars();
828    while let Some(ch) = chars.next() {
829        if ch == '\\' {
830            let Some(escaped) = chars.next() else {
831                return Err(ParseError::InvalidTerm {
832                    line,
833                    text: token.into(),
834                });
835            };
836            match escaped {
837                '\\' => result.push('\\'),
838                '"' => result.push('"'),
839                'n' => result.push('\n'),
840                'r' => result.push('\r'),
841                't' => result.push('\t'),
842                _ => {
843                    return Err(ParseError::InvalidTerm {
844                        line,
845                        text: token.into(),
846                    })
847                }
848            }
849        } else {
850            result.push(ch);
851        }
852    }
853
854    Ok(result)
855}
856
857fn parse_value_type(line: usize, token: &str) -> Result<ValueType, ParseError> {
858    let token = token.trim();
859    if let Some(inner) = token
860        .strip_prefix("List<")
861        .and_then(|rest| rest.strip_suffix('>'))
862    {
863        return Ok(ValueType::List(Box::new(parse_value_type(line, inner)?)));
864    }
865
866    match token {
867        "Bool" => Ok(ValueType::Bool),
868        "I64" => Ok(ValueType::I64),
869        "U64" => Ok(ValueType::U64),
870        "F64" => Ok(ValueType::F64),
871        "String" => Ok(ValueType::String),
872        "Bytes" => Ok(ValueType::Bytes),
873        "Entity" => Ok(ValueType::Entity),
874        _ => Err(ParseError::InvalidType {
875            line,
876            text: token.into(),
877        }),
878    }
879}
880
/// Splits `input` on `separator` at nesting depth zero.
///
/// Separators inside parentheses, angle brackets (e.g. `List<U64>`), or
/// double-quoted strings are ignored; `\` escapes the next character inside
/// a string. Pieces are trimmed but empty pieces are kept, so callers
/// filter them out. Errors on unbalanced `()`/`<>` or an unterminated
/// string.
fn split_top_level(input: &str, separator: char, line: usize) -> Result<Vec<String>, ParseError> {
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut paren_depth = 0usize;
    let mut angle_depth = 0usize;
    let mut in_string = false;
    let mut chars = input.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '"' => {
                in_string = !in_string;
                current.push(ch);
            }
            // Copy the backslash and the escaped character verbatim so an
            // escaped quote does not terminate the string.
            '\\' if in_string => {
                current.push(ch);
                if let Some(next) = chars.next() {
                    current.push(next);
                }
            }
            '(' if !in_string => {
                paren_depth += 1;
                current.push(ch);
            }
            ')' if !in_string => {
                if paren_depth == 0 {
                    return Err(ParseError::UnbalancedDelimiter { line });
                }
                paren_depth -= 1;
                current.push(ch);
            }
            '<' if !in_string => {
                angle_depth += 1;
                current.push(ch);
            }
            '>' if !in_string => {
                if angle_depth == 0 {
                    return Err(ParseError::UnbalancedDelimiter { line });
                }
                angle_depth -= 1;
                current.push(ch);
            }
            // Split only when completely outside every nesting construct.
            _ if ch == separator && !in_string && paren_depth == 0 && angle_depth == 0 => {
                parts.push(current.trim().to_owned());
                current.clear();
            }
            _ => current.push(ch),
        }
    }

    if in_string || paren_depth != 0 || angle_depth != 0 {
        return Err(ParseError::UnbalancedDelimiter { line });
    }

    parts.push(current.trim().to_owned());
    Ok(parts)
}
938
/// Removes a trailing `#` comment, ignoring `#` characters that appear
/// inside double-quoted strings (with `\"` escapes honored).
fn strip_comments(line: &str) -> &str {
    let mut in_string = false;
    let mut escaped = false;

    for (index, ch) in line.char_indices() {
        if escaped {
            // Previous char was a backslash inside a string: this one is
            // literal whatever it is.
            escaped = false;
            continue;
        }
        if in_string {
            match ch {
                '\\' => escaped = true,
                '"' => in_string = false,
                _ => {}
            }
        } else if ch == '"' {
            in_string = true;
        } else if ch == '#' {
            return &line[..index];
        }
    }

    line
}
958
/// Drops trailing statement terminators (`;` and `.`), however many.
fn trim_statement(line: &str) -> &str {
    line.trim_end_matches(|ch| ch == ';' || ch == '.')
}
962
963fn split_call_and_suffix(line: usize, entry: &str) -> Result<(&str, &str), ParseError> {
964    let mut in_string = false;
965    let mut paren_depth = 0usize;
966
967    for (index, ch) in entry.char_indices() {
968        match ch {
969            '"' => in_string = !in_string,
970            '(' if !in_string => paren_depth += 1,
971            ')' if !in_string => {
972                if paren_depth == 0 {
973                    return Err(ParseError::UnbalancedDelimiter { line });
974                }
975                paren_depth -= 1;
976                if paren_depth == 0 {
977                    return Ok((&entry[..=index], entry[index + 1..].trim()));
978                }
979            }
980            _ => {}
981        }
982    }
983
984    Err(ParseError::InvalidCall {
985        line,
986        text: entry.into(),
987    })
988}
989
990fn parse_policy_annotations(line: usize, rest: &str) -> Result<Option<PolicyEnvelope>, ParseError> {
991    if rest.is_empty() {
992        return Ok(None);
993    }
994
995    let mut capability = None;
996    let mut visibility = None;
997    let mut remaining = rest.trim();
998
999    while !remaining.is_empty() {
1000        let Some(annotation) = remaining.strip_prefix('@') else {
1001            return Err(ParseError::InvalidPolicyAnnotation {
1002                line,
1003                text: rest.into(),
1004            });
1005        };
1006        let (call, suffix) = split_call_and_suffix(line, annotation)?;
1007        let (name, args) = parse_call(line, call)?;
1008        if args.len() != 1 {
1009            return Err(ParseError::InvalidPolicyAnnotation {
1010                line,
1011                text: call.into(),
1012            });
1013        }
1014        let value = parse_fact_value(line, &args[0])?;
1015        let Value::String(value) = value else {
1016            return Err(ParseError::InvalidPolicyAnnotation {
1017                line,
1018                text: call.into(),
1019            });
1020        };
1021
1022        match name {
1023            "capability" => capability = Some(value),
1024            "visibility" => visibility = Some(value),
1025            _ => {
1026                return Err(ParseError::InvalidPolicyAnnotation {
1027                    line,
1028                    text: call.into(),
1029                })
1030            }
1031        }
1032
1033        remaining = suffix.trim();
1034    }
1035
1036    Ok(Some(
1037        PolicyEnvelope {
1038            capabilities: capability.into_iter().collect(),
1039            visibilities: visibility.into_iter().collect(),
1040        }
1041        .normalized(),
1042    ))
1043}
1044
/// Errors produced while parsing a DSL document.
///
/// Variants that originate from a specific input line carry the line
/// number so diagnostics can point back at the offending source text.
#[derive(Debug, Error)]
pub enum ParseError {
    /// Content found outside of any section block.
    #[error("line {line}: unexpected top-level content: {content}")]
    UnexpectedTopLevel { line: usize, content: String },
    /// A section header line that does not have the expected shape.
    #[error("line {line}: invalid section header {header}")]
    InvalidSectionHeader { line: usize, header: String },
    /// A section name the parser does not recognize.
    #[error("line {line}: unknown section {section}")]
    UnknownSection { line: usize, section: String },
    /// A section that may appear at most once appeared again.
    #[error("line {line}: duplicate section {section}")]
    DuplicateSection { line: usize, section: String },
    /// A required section (e.g. `schema`, `predicates`, `rules`) is absent.
    #[error("section {0} is required")]
    MissingSection(&'static str),
    /// A section was opened but never closed with a brace.
    #[error("line {line}: section {section} is missing a closing brace")]
    UnterminatedSection { line: usize, section: String },
    /// An entry inside the `schema` section could not be parsed.
    #[error("line {line}: invalid schema entry {entry}")]
    InvalidSchemaEntry { line: usize, entry: String },
    /// Malformed `name(args…)` call syntax.
    #[error("line {line}: invalid predicate or atom call {text}")]
    InvalidCall { line: usize, text: String },
    /// A rule that could not be parsed into head and body.
    #[error("line {line}: invalid rule {rule}")]
    InvalidRule { line: usize, rule: String },
    /// A rule, fact, or query references a predicate that was never declared.
    #[error("line {line}: unknown predicate {name}")]
    UnknownPredicateName { line: usize, name: String },
    /// A call supplies a different number of terms than the predicate declares.
    #[error("line {line}: predicate {predicate} has arity {actual}, expected {expected}")]
    PredicateArityMismatch {
        line: usize,
        predicate: String,
        expected: usize,
        actual: usize,
    },
    /// A type name that is not a known value type.
    #[error("line {line}: invalid type {text}")]
    InvalidType { line: usize, text: String },
    /// A term that is neither a valid variable, constant, nor aggregate.
    #[error("line {line}: invalid term {text}")]
    InvalidTerm { line: usize, text: String },
    /// A fact argument that could not be parsed into a concrete value.
    #[error("line {line}: invalid fact value {text}")]
    InvalidFactValue { line: usize, text: String },
    /// An entry inside a `query` section that could not be parsed.
    #[error("line {line}: invalid query entry {entry}")]
    InvalidQueryEntry { line: usize, entry: String },
    /// An entry inside an `explain` section that could not be parsed.
    #[error("line {line}: invalid explain entry {entry}")]
    InvalidExplainEntry { line: usize, entry: String },
    /// A malformed `@capability(…)` / `@visibility(…)` annotation.
    #[error("line {line}: invalid policy annotation {text}")]
    InvalidPolicyAnnotation { line: usize, text: String },
    /// The same predicate was listed more than once under `materialize`.
    #[error("line {line}: duplicate materialized predicate {name}")]
    DuplicateMaterializedPredicate { line: usize, name: String },
    /// Two named `query`/`explain` sections share the same name
    /// (`None` is the anonymous section).
    #[error("line {line}: duplicate {section} section name {name:?}")]
    DuplicateNamedSection {
        line: usize,
        section: String,
        name: Option<String>,
    },
    /// An `explain` section without a `tuple`/`plan` target declaration.
    #[error("line {line}: explain section {name:?} does not declare a target")]
    MissingExplainTarget { line: usize, name: Option<String> },
    /// An `explain` tuple containing variables; targets must be fully ground.
    #[error("line {line}: explain tuple must be ground: {entry}")]
    NonGroundExplainTuple { line: usize, entry: String },
    /// An unbalanced quote, parenthesis, or angle bracket in the input.
    #[error("line {line}: unbalanced delimiter in DSL input")]
    UnbalancedDelimiter { line: usize },
    /// A schema-level validation failure, wrapping the underlying [`SchemaError`].
    #[error("line {line}: schema error: {source}")]
    Schema { line: usize, source: SchemaError },
}
1103
// End-to-end tests: each one feeds a complete DSL document through the
// public `DslParser` entry point and asserts on the parsed AST (and, for
// the first test, on the compiled program as well).
#[cfg(test)]
mod tests {
    use super::{DefaultDslParser, DslParser, ParseError};
    use crate::{DefaultRuleCompiler, RuleCompiler};
    use aether_ast::{
        AggregateFunction, Atom, ExplainTarget, Literal, NamedQuerySpec, PredicateId, PredicateRef,
        QueryAst, QuerySpec, TemporalView, Term, Value,
    };
    use aether_schema::{AttributeClass, ValueType};

    // Happy path: attribute/predicate ids are assigned in declaration order,
    // the recursive transitive-closure program parses, and compiling it binds
    // the extensional predicate to its backing attribute.
    #[test]
    fn parses_document_and_compiles_recursive_program() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema v1 {
                  attr task.depends_on: RefSet<Entity>
                  attr task.labels: SetAddWins<String>
                }

                predicates {
                  task_depends_on(Entity, Entity)
                  depends_transitive(Entity, Entity)
                }

                rules {
                  depends_transitive(x, y) <- task_depends_on(x, y)
                  depends_transitive(x, z) <- depends_transitive(x, y), task_depends_on(y, z)
                }

                materialize {
                  depends_transitive
                }
                "#,
            )
            .expect("parse dsl document");

        assert_eq!(document.schema.version, "v1");
        assert_eq!(
            document
                .schema
                .attribute(&aether_ast::AttributeId::new(1))
                .expect("first attribute")
                .class,
            AttributeClass::RefSet
        );
        assert_eq!(
            document
                .schema
                .predicate(&PredicateId::new(1))
                .expect("first predicate")
                .fields,
            vec![ValueType::Entity, ValueType::Entity]
        );
        assert_eq!(document.program.rules.len(), 2);
        assert_eq!(document.program.materialized, vec![PredicateId::new(2)]);
        assert!(document.query.is_none());
        assert!(document.queries.is_empty());
        assert!(document.explains.is_empty());

        let compiled = DefaultRuleCompiler
            .compile(&document.schema, &document.program)
            .expect("compile parsed program");
        assert_eq!(
            compiled.extensional_bindings.get(&PredicateId::new(1)),
            Some(&aether_ast::AttributeId::new(1))
        );
    }

    // `not` literals parse as Literal::Negative, and quoted constants in a
    // rule body become Term::Value string terms (including one with a dash,
    // which must not be confused with an operator).
    #[test]
    fn parses_negation_and_constant_terms() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.status: ScalarLWW<String>
                }

                predicates {
                  task_status(Entity, String)
                  task(Entity)
                  blocked(Entity)
                }

                rules {
                  blocked(x) <- task(x), not task_status(x, "ready"), task_status(x, "blocked"), task_status(x, "retry-1")
                }

                materialize {
                  blocked
                }
                "#,
            )
            .expect("parse rule with negation and constants");

        let rule = &document.program.rules[0];
        assert!(matches!(rule.body[1], Literal::Negative(_)));
        let Literal::Positive(atom) = &rule.body[2] else {
            panic!("expected positive literal");
        };
        assert_eq!(
            atom.terms,
            vec![
                Term::Variable(aether_ast::Variable::new("x")),
                Term::Value(Value::String("blocked".into())),
            ]
        );
    }

    // Facts with @capability/@visibility annotations keep their policy
    // envelope; an anonymous `query` with `as_of` becomes both the primary
    // query and the single (unnamed) entry in `queries`.
    #[test]
    fn parses_facts_queries_as_of_and_policy_annotations() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema v2 {
                  attr task.status: ScalarLWW<String>
                }

                predicates {
                  execution_attempt(Entity, String, U64)
                  task_ready(Entity)
                }

                facts {
                  execution_attempt(entity(1), "worker-a", 1) @capability("executor") @visibility("ops")
                }

                rules {
                  task_ready(x) <- task_ready(x)
                }

                query {
                  as_of e5
                  goal task_ready(x)
                  keep x
                }
                "#,
            )
            .expect("parse document with facts and query");

        assert_eq!(document.schema.version, "v2");
        assert_eq!(document.program.facts.len(), 1);
        assert_eq!(
            document.program.facts[0].policy,
            Some(aether_ast::PolicyEnvelope {
                capabilities: vec!["executor".into()],
                visibilities: vec!["ops".into()],
            })
        );
        assert_eq!(
            document.query,
            Some(QuerySpec {
                view: TemporalView::AsOf(aether_ast::ElementId::new(5)),
                query: QueryAst {
                    goals: vec![Atom {
                        predicate: PredicateRef {
                            id: PredicateId::new(2),
                            name: "task_ready".into(),
                            arity: 1,
                        },
                        terms: vec![Term::Variable(aether_ast::Variable::new("x"))],
                    }],
                    keep: vec![aether_ast::Variable::new("x")],
                },
            })
        );
        assert_eq!(
            document.queries,
            vec![NamedQuerySpec {
                name: None,
                spec: document.query.clone().expect("primary query"),
            }]
        );
        assert!(document.explains.is_empty());
    }

    // A `count(dep)` in a rule head parses as a Term::Aggregate with the
    // Count function over the named body variable.
    #[test]
    fn parses_head_aggregates_for_bounded_aggregation_rules() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.depends_on: RefSet<Entity>
                }

                predicates {
                  task_depends_on(Entity, Entity)
                  dependency_count(Entity, U64)
                }

                rules {
                  dependency_count(task, count(dep)) <- task_depends_on(task, dep)
                }

                materialize {
                  dependency_count
                }
                "#,
            )
            .expect("parse aggregate rule");

        let aggregate_rule = &document.program.rules[0];
        assert!(matches!(
            &aggregate_rule.head.terms[1],
            Term::Aggregate(aggregate)
                if aggregate.function == AggregateFunction::Count
                    && aggregate.variable == aether_ast::Variable::new("dep")
        ));
    }

    // Multiple aggregates in the same head (count + sum) are parsed
    // independently, each bound to its own variable.
    #[test]
    fn parses_multiple_head_aggregates_in_one_rule() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr project.task: RefSet<Entity>
                  attr task.hours: ScalarLWW<U64>
                }

                predicates {
                  project_task(Entity, Entity)
                  task_hours(Entity, U64)
                  project_stats(Entity, U64, U64)
                }

                rules {
                  project_stats(project, count(task), sum(hours)) <- project_task(project, task), task_hours(task, hours)
                }

                materialize {
                  project_stats
                }
                "#,
            )
            .expect("parse multi aggregate rule");

        let aggregate_rule = &document.program.rules[0];
        assert!(matches!(
            &aggregate_rule.head.terms[1],
            Term::Aggregate(aggregate)
                if aggregate.function == AggregateFunction::Count
                    && aggregate.variable == aether_ast::Variable::new("task")
        ));
        assert!(matches!(
            &aggregate_rule.head.terms[2],
            Term::Aggregate(aggregate)
                if aggregate.function == AggregateFunction::Sum
                    && aggregate.variable == aether_ast::Variable::new("hours")
        ));
    }

    // Named `query` sections are collected in order (the first doubling as
    // the primary query); `explain` sections support both ground-tuple and
    // plan targets.
    #[test]
    fn parses_named_queries_and_explain_directives() {
        let document = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.depends_on: RefSet<Entity>
                }

                predicates {
                  task_depends_on(Entity, Entity)
                  depends_transitive(Entity, Entity)
                }

                rules {
                  depends_transitive(x, y) <- task_depends_on(x, y)
                }

                materialize {
                  depends_transitive
                }

                query ready_now {
                  current
                  find depends_transitive(entity(1), y)
                  keep y
                }

                query ready_then {
                  as_of e7
                  goal depends_transitive(entity(1), y)
                  keep y
                }

                explain proof_now {
                  current
                  tuple depends_transitive(entity(1), entity(2))
                }

                explain plan_view {
                  plan
                }
                "#,
            )
            .expect("parse named queries and explain directives");

        assert_eq!(document.queries.len(), 2);
        assert_eq!(document.queries[0].name.as_deref(), Some("ready_now"));
        assert_eq!(document.queries[1].name.as_deref(), Some("ready_then"));
        assert_eq!(document.query, Some(document.queries[0].spec.clone()));

        assert_eq!(document.explains.len(), 2);
        assert!(matches!(
            &document.explains[0].spec.target,
            ExplainTarget::Tuple(atom)
                if atom.terms
                    == vec![
                        Term::Value(Value::Entity(aether_ast::EntityId::new(1))),
                        Term::Value(Value::Entity(aether_ast::EntityId::new(2))),
                    ]
        ));
        assert_eq!(document.explains[1].name.as_deref(), Some("plan_view"));
        assert_eq!(document.explains[1].spec.target, ExplainTarget::Plan);
    }

    // A rule head using an undeclared predicate is rejected with
    // UnknownPredicateName naming the offender.
    #[test]
    fn rejects_unknown_predicates() {
        let error = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.status: ScalarLWW<String>
                }

                predicates {
                  task_status(Entity, String)
                }

                rules {
                  blocked(x) <- task_status(x, "ready")
                }
                "#,
            )
            .expect_err("unknown rule head predicate should fail");

        assert!(matches!(
            error,
            ParseError::UnknownPredicateName { name, .. } if name == "blocked"
        ));
    }

    // Two failure modes: an unknown type alias inside a schema attribute
    // (InvalidType carries the alias text), and listing the same predicate
    // twice under `materialize` (DuplicateMaterializedPredicate).
    #[test]
    fn rejects_unknown_types_and_duplicate_materialize_entries() {
        let unknown_type = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.owner: RefScalar<Task>
                }

                predicates {
                  task_owner(Entity, Entity)
                }

                rules {
                  task_owner(x, y) <- task_owner(x, y)
                }
                "#,
            )
            .expect_err("unknown type alias should fail");
        assert!(matches!(unknown_type, ParseError::InvalidType { text, .. } if text == "Task"));

        let duplicate_materialize = DefaultDslParser
            .parse_document(
                r#"
                schema {
                  attr task.status: ScalarLWW<String>
                }

                predicates {
                  task_status(Entity, String)
                }

                rules {
                  task_status(x, s) <- task_status(x, s)
                }

                materialize {
                  task_status
                  task_status
                }
                "#,
            )
            .expect_err("duplicate materialize should fail");
        assert!(matches!(
            duplicate_materialize,
            ParseError::DuplicateMaterializedPredicate { name, .. } if name == "task_status"
        ));
    }
}