Skip to main content

gram_codec/cst/
lowering.rs

1//! Lowering from syntax-preserving CST nodes to semantic patterns.
2
3use crate::cst::{Annotation, ArrowKind, SyntaxKind, SyntaxNode};
4use crate::{Pattern, Subject};
5use pattern_core::{RangeValue, Symbol, Value};
6use std::collections::{HashMap, HashSet};
7
8pub fn lower(tree: Pattern<SyntaxNode>) -> Vec<Pattern<Subject>> {
9    assert!(
10        matches!(tree.value.kind, SyntaxKind::Document),
11        "lower expects a document root"
12    );
13
14    let mut lowered = Vec::new();
15
16    if let Some(subject) = tree.value.subject {
17        lowered.push(Pattern::point(subject));
18    }
19
20    for element in tree.elements {
21        if let Some(pattern) = lower_node(element) {
22            lowered.push(pattern);
23        }
24    }
25
26    lowered
27}
28
29fn lower_node(node: Pattern<SyntaxNode>) -> Option<Pattern<Subject>> {
30    match node.value.kind {
31        SyntaxKind::Document => unreachable!("document nodes are only handled at the root"),
32        SyntaxKind::Node => Some(Pattern::point(
33            node.value.subject.unwrap_or_else(empty_subject),
34        )),
35        SyntaxKind::Subject => Some(Pattern::pattern(
36            node.value.subject.unwrap_or_else(empty_subject),
37            node.elements
38                .into_iter()
39                .filter_map(lower_node)
40                .collect::<Vec<_>>(),
41        )),
42        SyntaxKind::Relationship(_) => Some(lower_relationship(node)),
43        SyntaxKind::Annotated => {
44            let mut elements = node.elements.into_iter().filter_map(lower_node);
45            let inner = elements.next()?;
46            Some(Pattern::pattern(
47                annotation_subject(&node.value.annotations),
48                vec![inner],
49            ))
50        }
51        SyntaxKind::Comment => None,
52    }
53}
54
55fn lower_relationship(node: Pattern<SyntaxNode>) -> Pattern<Subject> {
56    let (operands, relationships) = flatten_relationship_chain(node);
57    let mut operands = operands.into_iter();
58    let mut acc = lower_node(
59        operands
60            .next()
61            .expect("relationship chain should have a first operand"),
62    )
63    .expect("relationship operands should lower to patterns");
64
65    for ((arrow_kind, subject), operand) in relationships.into_iter().zip(operands) {
66        let next =
67            lower_node(operand).expect("relationship chain operands should lower to patterns");
68        let elements = if matches!(arrow_kind, ArrowKind::Left) {
69            vec![next, acc]
70        } else {
71            vec![acc, next]
72        };
73        acc = Pattern::pattern(subject, elements);
74    }
75
76    acc
77}
78
79fn flatten_relationship_chain(
80    node: Pattern<SyntaxNode>,
81) -> (Vec<Pattern<SyntaxNode>>, Vec<(ArrowKind, Subject)>) {
82    let arrow_kind = match node.value.kind {
83        SyntaxKind::Relationship(arrow_kind) => arrow_kind,
84        _ => unreachable!("flatten_relationship_chain only accepts relationships"),
85    };
86
87    let mut elements = node.elements.into_iter();
88    let left = elements
89        .next()
90        .expect("relationship nodes should have a left operand");
91    let right = elements
92        .next()
93        .expect("relationship nodes should have a right operand");
94
95    assert!(
96        elements.next().is_none(),
97        "relationship nodes should have exactly two operands"
98    );
99
100    let mut operands = vec![left];
101    let mut relationships = vec![(arrow_kind, node.value.subject.unwrap_or_else(empty_subject))];
102
103    match right.value.kind {
104        SyntaxKind::Relationship(_) => {
105            let (mut child_operands, mut child_relationships) = flatten_relationship_chain(right);
106            operands.append(&mut child_operands);
107            relationships.append(&mut child_relationships);
108        }
109        _ => operands.push(right),
110    }
111
112    (operands, relationships)
113}
114
115fn annotation_subject(annotations: &[Annotation]) -> Subject {
116    let mut identity = Symbol(String::new());
117    let mut labels = HashSet::new();
118    let mut properties = HashMap::new();
119
120    for annotation in annotations {
121        match annotation {
122            Annotation::Property { key, value } => {
123                properties.insert(key.clone(), lower_annotation_value(value));
124            }
125            Annotation::Identified {
126                identity: annotation_identity,
127                labels: annotation_labels,
128            } => {
129                if let Some(annotation_identity) = annotation_identity {
130                    identity = annotation_identity.clone();
131                }
132
133                for label in annotation_labels {
134                    labels.insert(label.clone());
135                }
136            }
137        }
138    }
139
140    Subject {
141        identity,
142        labels,
143        properties,
144    }
145}
146
147fn lower_annotation_value(value: &crate::Value) -> Value {
148    match value {
149        crate::Value::String(value) => Value::VString(value.clone()),
150        crate::Value::Integer(value) => Value::VInteger(*value),
151        crate::Value::Decimal(value) => Value::VDecimal(*value),
152        crate::Value::Boolean(value) => Value::VBoolean(*value),
153        crate::Value::Array(values) => Value::VArray(
154            values
155                .iter()
156                .map(lower_annotation_value)
157                .collect::<Vec<_>>(),
158        ),
159        crate::Value::Range { lower, upper } => Value::VRange(RangeValue {
160            lower: Some(*lower as f64),
161            upper: Some(*upper as f64),
162        }),
163        crate::Value::TaggedString { tag, content } => Value::VTaggedString {
164            tag: tag.clone(),
165            content: content.clone(),
166        },
167    }
168}
169
170fn empty_subject() -> Subject {
171    Subject {
172        identity: Symbol(String::new()),
173        labels: Default::default(),
174        properties: Default::default(),
175    }
176}