Skip to main content

gram_codec/cst/
parser.rs

1//! tree-sitter-backed CST parsing entry points.
2
3use crate::cst::{Annotation, ArrowKind, CstParseResult, SourceSpan, SyntaxKind, SyntaxNode};
4use crate::{Pattern, Subject, Value};
5use std::collections::HashSet;
6use tree_sitter::{Node, Parser};
7
8pub fn parse_gram_cst(input: &str) -> CstParseResult {
9    let mut parser = Parser::new();
10    let mut errors = Vec::new();
11    if parser
12        .set_language(&tree_sitter_gram::LANGUAGE.into())
13        .is_err()
14    {
15        errors.push(whole_input_span(input));
16        return CstParseResult {
17            tree: document_tree(input, whole_input_span(input), None, vec![]),
18            errors,
19        };
20    }
21
22    let Some(tree) = parser.parse(input, None) else {
23        errors.push(whole_input_span(input));
24        return CstParseResult {
25            tree: document_tree(input, whole_input_span(input), None, vec![]),
26            errors,
27        };
28    };
29
30    let root = tree.root_node();
31    if root.kind() != "gram_pattern" {
32        record_error(&mut errors, root);
33    }
34
35    let mut elements = Vec::new();
36    let mut cursor = root.walk();
37    for child in root.children(&mut cursor) {
38        if !child.is_named() {
39            continue;
40        }
41
42        match child.kind() {
43            "record" => {}
44            "node_pattern"
45            | "relationship_pattern"
46            | "subject_pattern"
47            | "annotated_pattern"
48            | "comment" => {
49                if let Some(element) = convert_named_node(child, input, &mut errors) {
50                    elements.push(element);
51                }
52            }
53            _ => {}
54        }
55    }
56
57    errors.extend(collect_error_spans(root));
58    dedupe_errors(&mut errors);
59
60    CstParseResult {
61        tree: document_tree(
62            input,
63            span_from_node(root),
64            root.child_by_field_name("root")
65                .map(|record| extract_record_subject(record, input)),
66            elements,
67        ),
68        errors,
69    }
70}
71
72fn convert_named_node(
73    node: Node<'_>,
74    input: &str,
75    errors: &mut Vec<SourceSpan>,
76) -> Option<Pattern<SyntaxNode>> {
77    match node.kind() {
78        "node_pattern" => Some(convert_node_pattern(node, input)),
79        "relationship_pattern" => Some(convert_relationship_pattern(node, input, errors)),
80        "subject_pattern" => Some(convert_subject_pattern(node, input, errors)),
81        "annotated_pattern" => Some(convert_annotated_pattern(node, input, errors)),
82        "comment" => Some(convert_comment(node, input)),
83        "pattern_reference" => Some(convert_pattern_reference(node, input, errors)),
84        _ => {
85            record_error(errors, node);
86            None
87        }
88    }
89}
90
91fn convert_node_pattern(node: Node<'_>, input: &str) -> Pattern<SyntaxNode> {
92    Pattern::point(SyntaxNode {
93        kind: SyntaxKind::Node,
94        subject: extract_subject(node, input),
95        span: span_from_node(node),
96        annotations: vec![],
97        text: None,
98    })
99}
100
101fn convert_relationship_pattern(
102    node: Node<'_>,
103    input: &str,
104    errors: &mut Vec<SourceSpan>,
105) -> Pattern<SyntaxNode> {
106    let left = node
107        .child_by_field_name("left")
108        .and_then(|child| convert_named_node(child, input, errors))
109        .unwrap_or_else(|| {
110            record_error(errors, node);
111            fallback_pattern(node)
112        });
113    let right = node
114        .child_by_field_name("right")
115        .and_then(|child| convert_named_node(child, input, errors))
116        .unwrap_or_else(|| {
117            record_error(errors, node);
118            fallback_pattern(node)
119        });
120    let arrow_node = node.child_by_field_name("kind");
121    let arrow_kind = arrow_node
122        .map(|kind| arrow_kind(kind.kind(), kind, errors))
123        .unwrap_or_else(|| {
124            record_error(errors, node);
125            ArrowKind::Right
126        });
127
128    Pattern::pattern(
129        SyntaxNode {
130            kind: SyntaxKind::Relationship(arrow_kind),
131            subject: arrow_node.and_then(|kind| extract_subject(kind, input)),
132            span: span_from_node(node),
133            annotations: vec![],
134            text: None,
135        },
136        vec![left, right],
137    )
138}
139
140fn convert_subject_pattern(
141    node: Node<'_>,
142    input: &str,
143    errors: &mut Vec<SourceSpan>,
144) -> Pattern<SyntaxNode> {
145    let mut elements = Vec::new();
146
147    let mut node_cursor = node.walk();
148    let elements_node = node
149        .children(&mut node_cursor)
150        .find(|child| child.is_named() && child.kind() == "subject_pattern_elements");
151
152    if let Some(elements_node) = elements_node {
153        let mut cursor = elements_node.walk();
154        for child in elements_node.children(&mut cursor) {
155            if !child.is_named() {
156                continue;
157            }
158
159            match child.kind() {
160                "pattern_reference"
161                | "node_pattern"
162                | "relationship_pattern"
163                | "subject_pattern"
164                | "annotated_pattern" => {
165                    if let Some(element) = convert_named_node(child, input, errors) {
166                        elements.push(element);
167                    }
168                }
169                _ => {}
170            }
171        }
172    }
173
174    Pattern::pattern(
175        SyntaxNode {
176            kind: SyntaxKind::Subject,
177            subject: extract_subject(node, input),
178            span: span_from_node(node),
179            annotations: vec![],
180            text: None,
181        },
182        elements,
183    )
184}
185
186fn convert_annotated_pattern(
187    node: Node<'_>,
188    input: &str,
189    errors: &mut Vec<SourceSpan>,
190) -> Pattern<SyntaxNode> {
191    let annotations = node
192        .child_by_field_name("annotations")
193        .map(|annotations_node| extract_annotations(annotations_node, input, errors))
194        .unwrap_or_default();
195    let inner = node
196        .child_by_field_name("elements")
197        .and_then(|child| convert_named_node(child, input, errors));
198
199    Pattern::pattern(
200        SyntaxNode {
201            kind: SyntaxKind::Annotated,
202            subject: None,
203            span: span_from_node(node),
204            annotations,
205            text: None,
206        },
207        inner.into_iter().collect(),
208    )
209}
210
211fn convert_comment(node: Node<'_>, input: &str) -> Pattern<SyntaxNode> {
212    Pattern::point(SyntaxNode {
213        kind: SyntaxKind::Comment,
214        subject: None,
215        span: span_from_node(node),
216        annotations: vec![],
217        text: Some(node_text(node, input).to_string()),
218    })
219}
220
221fn convert_pattern_reference(
222    node: Node<'_>,
223    input: &str,
224    errors: &mut Vec<SourceSpan>,
225) -> Pattern<SyntaxNode> {
226    let identifier = node
227        .child_by_field_name("identifier")
228        .map(|child| extract_identifier(child, input))
229        .or_else(|| {
230            record_error(errors, node);
231            let raw = node_text(node, input).trim();
232            (!raw.is_empty()).then(|| raw.to_string())
233        })
234        .unwrap_or_default();
235
236    Pattern::point(SyntaxNode {
237        kind: SyntaxKind::Node,
238        subject: Some(Subject {
239            identity: pattern_core::Symbol(identifier),
240            labels: HashSet::new(),
241            properties: Default::default(),
242        }),
243        span: span_from_node(node),
244        annotations: vec![],
245        text: None,
246    })
247}
248
249fn extract_annotations(
250    node: Node<'_>,
251    input: &str,
252    errors: &mut Vec<SourceSpan>,
253) -> Vec<Annotation> {
254    let mut annotations = Vec::new();
255    let mut cursor = node.walk();
256
257    for child in node.children(&mut cursor) {
258        if !child.is_named() {
259            continue;
260        }
261
262        match child.kind() {
263            "property_annotation" => {
264                annotations.push(extract_property_annotation(child, input, errors))
265            }
266            "identified_annotation" => {
267                annotations.push(extract_identified_annotation(child, input))
268            }
269            _ => {}
270        }
271    }
272
273    annotations
274}
275
276fn extract_property_annotation(
277    node: Node<'_>,
278    input: &str,
279    errors: &mut Vec<SourceSpan>,
280) -> Annotation {
281    let key = node
282        .child_by_field_name("key")
283        .map(|child| node_text(child, input).to_string())
284        .unwrap_or_else(|| {
285            record_error(errors, node);
286            String::new()
287        });
288    let value = node
289        .child_by_field_name("value")
290        .map(|value_node| extract_annotation_value(value_node, input))
291        .unwrap_or(Value::Boolean(true));
292
293    Annotation::Property { key, value }
294}
295
296fn extract_identified_annotation(node: Node<'_>, input: &str) -> Annotation {
297    let identity = node
298        .child_by_field_name("identifier")
299        .map(|child| pattern_core::Symbol(extract_identifier(child, input)));
300    let labels = node
301        .child_by_field_name("labels")
302        .map(|labels| extract_label_list(labels, input))
303        .unwrap_or_default();
304
305    Annotation::Identified { identity, labels }
306}
307
308fn extract_annotation_value(node: Node<'_>, input: &str) -> Value {
309    let raw = node_text(node, input);
310    let parsed = crate::parser::value::value_parser(raw)
311        .ok()
312        .and_then(|(remaining, value)| remaining.trim().is_empty().then_some(value));
313
314    match parsed {
315        Some(pattern_core::Value::VString(value)) => Value::String(value),
316        Some(pattern_core::Value::VSymbol(value)) => Value::String(value),
317        Some(pattern_core::Value::VInteger(value)) => Value::Integer(value),
318        Some(pattern_core::Value::VDecimal(value)) => Value::Decimal(value),
319        Some(pattern_core::Value::VBoolean(value)) => Value::Boolean(value),
320        Some(pattern_core::Value::VArray(values)) => Value::Array(
321            values
322                .into_iter()
323                .map(pattern_value_to_annotation_value)
324                .collect(),
325        ),
326        Some(pattern_core::Value::VRange(range)) => match (range.lower, range.upper) {
327            (Some(lower), Some(upper)) if lower.fract() == 0.0 && upper.fract() == 0.0 => {
328                Value::Range {
329                    lower: lower as i64,
330                    upper: upper as i64,
331                }
332            }
333            _ => Value::String(raw.to_string()),
334        },
335        Some(pattern_core::Value::VTaggedString { tag, content }) => {
336            Value::TaggedString { tag, content }
337        }
338        Some(pattern_core::Value::VMap(_)) | Some(pattern_core::Value::VMeasurement { .. }) => {
339            Value::String(raw.to_string())
340        }
341        None => Value::String(raw.to_string()),
342    }
343}
344
345fn pattern_value_to_annotation_value(value: pattern_core::Value) -> Value {
346    match value {
347        pattern_core::Value::VString(value) => Value::String(value),
348        pattern_core::Value::VSymbol(value) => Value::String(value),
349        pattern_core::Value::VInteger(value) => Value::Integer(value),
350        pattern_core::Value::VDecimal(value) => Value::Decimal(value),
351        pattern_core::Value::VBoolean(value) => Value::Boolean(value),
352        pattern_core::Value::VArray(values) => Value::Array(
353            values
354                .into_iter()
355                .map(pattern_value_to_annotation_value)
356                .collect(),
357        ),
358        pattern_core::Value::VRange(range) => match (range.lower, range.upper) {
359            (Some(lower), Some(upper)) if lower.fract() == 0.0 && upper.fract() == 0.0 => {
360                Value::Range {
361                    lower: lower as i64,
362                    upper: upper as i64,
363                }
364            }
365            _ => Value::String(format!("{range}")),
366        },
367        pattern_core::Value::VTaggedString { tag, content } => Value::TaggedString { tag, content },
368        pattern_core::Value::VMap(map) => Value::String(pattern_core::Value::VMap(map).to_string()),
369        pattern_core::Value::VMeasurement { unit, value } => {
370            Value::String(format!("{value}{unit}"))
371        }
372    }
373}
374
375fn extract_subject(node: Node<'_>, input: &str) -> Option<Subject> {
376    let has_identifier = node.child_by_field_name("identifier").is_some();
377    let has_labels = node.child_by_field_name("labels").is_some();
378    let has_record = node.child_by_field_name("record").is_some();
379    let has_subject = node.child_by_field_name("subject").is_some();
380
381    if !(has_identifier || has_labels || has_record || has_subject) {
382        return None;
383    }
384
385    let identity = node
386        .child_by_field_name("identifier")
387        .map(|child| pattern_core::Symbol(extract_identifier(child, input)))
388        .unwrap_or_else(|| pattern_core::Symbol(String::new()));
389    let labels = node
390        .child_by_field_name("labels")
391        .map(|labels_node| extract_labels(labels_node, input))
392        .unwrap_or_default();
393    let properties = node
394        .child_by_field_name("record")
395        .map(|record| extract_record(record, input))
396        .unwrap_or_default();
397
398    Some(Subject {
399        identity,
400        labels,
401        properties,
402    })
403}
404
405fn extract_record_subject(node: Node<'_>, input: &str) -> Subject {
406    Subject {
407        identity: pattern_core::Symbol(String::new()),
408        labels: HashSet::new(),
409        properties: extract_record(node, input),
410    }
411}
412
413fn extract_record(node: Node<'_>, input: &str) -> pattern_core::PropertyRecord {
414    let raw = node_text(node, input);
415    crate::parser::subject::record(raw)
416        .ok()
417        .and_then(|(remaining, record)| remaining.trim().is_empty().then_some(record))
418        .unwrap_or_default()
419}
420
421fn extract_labels(node: Node<'_>, input: &str) -> HashSet<String> {
422    extract_label_list(node, input).into_iter().collect()
423}
424
425fn extract_label_list(node: Node<'_>, input: &str) -> Vec<String> {
426    let mut labels = Vec::new();
427    let mut cursor = node.walk();
428
429    for child in node.children(&mut cursor) {
430        if !child.is_named() {
431            continue;
432        }
433        match child.kind() {
434            "symbol" => labels.push(node_text(child, input).to_string()),
435            "quoted_name" => labels.push(extract_identifier(child, input)),
436            _ => {}
437        }
438    }
439
440    labels
441}
442
443fn extract_identifier(node: Node<'_>, input: &str) -> String {
444    let raw = node_text(node, input);
445    crate::parser::value::identifier(raw)
446        .ok()
447        .and_then(|(remaining, identifier)| remaining.trim().is_empty().then_some(identifier))
448        .unwrap_or_else(|| raw.to_string())
449}
450
451fn collect_error_spans(node: Node<'_>) -> Vec<SourceSpan> {
452    let mut spans = Vec::new();
453    collect_error_spans_inner(node, &mut spans);
454    spans
455}
456
457fn collect_error_spans_inner(node: Node<'_>, spans: &mut Vec<SourceSpan>) {
458    if node.is_error() {
459        spans.push(span_from_node(node));
460    }
461
462    if !(node.is_error() || node.has_error()) {
463        return;
464    }
465
466    let mut cursor = node.walk();
467    for child in node.children(&mut cursor) {
468        if child.is_error() || child.has_error() {
469            collect_error_spans_inner(child, spans);
470        }
471    }
472}
473
474fn arrow_kind(kind: &str, node: Node<'_>, errors: &mut Vec<SourceSpan>) -> ArrowKind {
475    match kind {
476        "right_arrow" => ArrowKind::Right,
477        "left_arrow" => ArrowKind::Left,
478        "bidirectional_arrow" => ArrowKind::Bidirectional,
479        "undirected_arrow" => ArrowKind::Undirected,
480        _ => {
481            record_error(errors, node);
482            ArrowKind::Right
483        }
484    }
485}
486
487fn span_from_node(node: Node<'_>) -> SourceSpan {
488    SourceSpan {
489        start: node.start_byte(),
490        end: node.end_byte(),
491    }
492}
493
494fn node_text<'a>(node: Node<'_>, input: &'a str) -> &'a str {
495    node.utf8_text(input.as_bytes()).unwrap_or("")
496}
497
498fn document_tree(
499    _input: &str,
500    span: SourceSpan,
501    subject: Option<Subject>,
502    elements: Vec<Pattern<SyntaxNode>>,
503) -> Pattern<SyntaxNode> {
504    Pattern::pattern(
505        SyntaxNode {
506            kind: SyntaxKind::Document,
507            subject,
508            span,
509            annotations: vec![],
510            text: None,
511        },
512        elements,
513    )
514}
515
516fn fallback_pattern(node: Node<'_>) -> Pattern<SyntaxNode> {
517    Pattern::point(SyntaxNode {
518        kind: SyntaxKind::Node,
519        subject: None,
520        span: span_from_node(node),
521        annotations: vec![],
522        text: None,
523    })
524}
525
526fn whole_input_span(input: &str) -> SourceSpan {
527    SourceSpan {
528        start: 0,
529        end: input.len(),
530    }
531}
532
533fn record_error(errors: &mut Vec<SourceSpan>, node: Node<'_>) {
534    errors.push(span_from_node(node));
535}
536
537fn dedupe_errors(errors: &mut Vec<SourceSpan>) {
538    errors.sort_by_key(|span| (span.start, span.end));
539    errors.dedup_by(|left, right| left.start == right.start && left.end == right.end);
540}