1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
23#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
25pub enum DelayType {
26 Stratum,
28 MonotoneAccum,
30 Tick,
32 TickLazy,
34}
35
36pub enum PortListSpec {
38 Variadic,
40 Fixed(Punctuated<PortIndex, Token![,]>),
42}
43
44pub struct OperatorConstraints {
46 pub name: &'static str,
48 pub categories: &'static [OperatorCategory],
50
51 pub hard_range_inn: &'static dyn RangeTrait<usize>,
54 pub soft_range_inn: &'static dyn RangeTrait<usize>,
56 pub hard_range_out: &'static dyn RangeTrait<usize>,
58 pub soft_range_out: &'static dyn RangeTrait<usize>,
60 pub num_args: usize,
62 pub persistence_args: &'static dyn RangeTrait<usize>,
64 pub type_args: &'static dyn RangeTrait<usize>,
68 pub is_external_input: bool,
71 pub has_singleton_output: bool,
75 pub flo_type: Option<FloType>,
77
78 pub ports_inn: Option<fn() -> PortListSpec>,
80 pub ports_out: Option<fn() -> PortListSpec>,
82
83 pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
85 pub write_fn: WriteFn,
87}
88
89pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
113#[derive(Default)]
117pub struct OperatorWriteOutput {
118 pub write_prologue: TokenStream,
121 pub write_iterator: TokenStream,
128 pub write_iterator_after: TokenStream,
130 pub write_tick_end: TokenStream,
133}
134
135pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
137pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
139pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
170 where
171 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
172 {
173 pull
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
183 where
184 Psh: #root::dfir_pipes::push::Push<Item, ()>,
185 {
186 push
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::dfir_pipes::pull::poll_fn({
227 #(
228 let mut #inputs = ::std::boxed::Box::pin(#inputs);
229 )*
230 move |_cx| {
231 #(
235 let #inputs = #root::dfir_pipes::pull::Pull::pull(
236 ::std::pin::Pin::as_mut(&mut #inputs),
237 <_ as #root::dfir_pipes::Context>::from_task(_cx),
238 );
239 )*
240 #(
241 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
242 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
243 }
244 )*
245 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
246 }
247 });
248 }
249 } else {
250 quote_spanned! {op_span=>
251 #[allow(clippy::let_unit_value)]
252 let _ = (#(#outputs),*);
253 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
254 }
255 }
256}
257
258pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
261 let write_iterator = null_write_iterator_fn(write_context_args);
262 Ok(OperatorWriteOutput {
263 write_iterator,
264 ..Default::default()
265 })
266};
267
268macro_rules! declare_ops {
269 ( $( $mod:ident :: $op:ident, )* ) => {
270 $( pub(crate) mod $mod; )*
271 pub const OPERATORS: &[OperatorConstraints] = &[
273 $( $mod :: $op, )*
274 ];
275 };
276}
277declare_ops![
278 all_iterations::ALL_ITERATIONS,
279 all_once::ALL_ONCE,
280 anti_join::ANTI_JOIN,
281 assert::ASSERT,
282 assert_eq::ASSERT_EQ,
283 batch::BATCH,
284 chain::CHAIN,
285 chain_first_n::CHAIN_FIRST_N,
286 _counter::_COUNTER,
287 cross_join::CROSS_JOIN,
288 cross_join_multiset::CROSS_JOIN_MULTISET,
289 cross_singleton::CROSS_SINGLETON,
290 demux_enum::DEMUX_ENUM,
291 dest_file::DEST_FILE,
292 dest_sink::DEST_SINK,
293 dest_sink_serde::DEST_SINK_SERDE,
294 difference::DIFFERENCE,
295 enumerate::ENUMERATE,
296 filter::FILTER,
297 filter_map::FILTER_MAP,
298 flat_map::FLAT_MAP,
299 flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
300 flatten::FLATTEN,
301 flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
302 fold::FOLD,
303 fold_no_replay::FOLD_NO_REPLAY,
304 for_each::FOR_EACH,
305 identity::IDENTITY,
306 initialize::INITIALIZE,
307 inspect::INSPECT,
308 join::JOIN,
309 join_fused::JOIN_FUSED,
310 join_fused_lhs::JOIN_FUSED_LHS,
311 join_fused_rhs::JOIN_FUSED_RHS,
312 join_multiset::JOIN_MULTISET,
313 join_multiset_half::JOIN_MULTISET_HALF,
314 fold_keyed::FOLD_KEYED,
315 reduce_keyed::REDUCE_KEYED,
316 repeat_n::REPEAT_N,
317 lattice_bimorphism::LATTICE_BIMORPHISM,
319 _lattice_fold_batch::_LATTICE_FOLD_BATCH,
320 lattice_fold::LATTICE_FOLD,
321 _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
322 lattice_reduce::LATTICE_REDUCE,
323 map::MAP,
324 union::UNION,
325 multiset_delta::MULTISET_DELTA,
326 next_iteration::NEXT_ITERATION,
327 defer_signal::DEFER_SIGNAL,
328 defer_tick::DEFER_TICK,
329 defer_tick_lazy::DEFER_TICK_LAZY,
330 null::NULL,
331 partition::PARTITION,
332 persist::PERSIST,
333 persist_mut::PERSIST_MUT,
334 persist_mut_keyed::PERSIST_MUT_KEYED,
335 prefix::PREFIX,
336 resolve_futures::RESOLVE_FUTURES,
337 resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
338 resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
339 resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
340 reduce::REDUCE,
341 reduce_no_replay::REDUCE_NO_REPLAY,
342 scan::SCAN,
343 scan_async_blocking::SCAN_ASYNC_BLOCKING,
344 spin::SPIN,
345 sort::SORT,
346 sort_by_key::SORT_BY_KEY,
347 source_file::SOURCE_FILE,
348 source_interval::SOURCE_INTERVAL,
349 source_iter::SOURCE_ITER,
350 source_json::SOURCE_JSON,
351 source_stdin::SOURCE_STDIN,
352 source_stream::SOURCE_STREAM,
353 source_stream_serde::SOURCE_STREAM_SERDE,
354 state::STATE,
355 state_by::STATE_BY,
356 tee::TEE,
357 unique::UNIQUE,
358 unzip::UNZIP,
359 zip::ZIP,
360 zip_longest::ZIP_LONGEST,
361];
362
363pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
365 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
366 OnceLock::new();
367 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
368}
369pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
371 if let GraphNode::Operator(operator) = node {
372 find_op_op_constraints(operator)
373 } else {
374 None
375 }
376}
377pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
379 let name = &*operator.name_string();
380 operator_lookup().get(name).copied()
381}
382
383#[derive(Clone)]
385pub struct WriteContextArgs<'a> {
386 pub root: &'a TokenStream,
388 pub context: &'a Ident,
391 pub df_ident: &'a Ident,
395 pub subgraph_id: GraphSubgraphId,
397 pub node_id: GraphNodeId,
399 pub loop_id: Option<GraphLoopId>,
401 pub op_span: Span,
403 pub op_tag: Option<String>,
405 pub work_fn: &'a Ident,
407 pub work_fn_async: &'a Ident,
409
410 pub ident: &'a Ident,
412 pub is_pull: bool,
414 pub inputs: &'a [Ident],
416 pub outputs: &'a [Ident],
418 pub singleton_output_ident: &'a Ident,
420
421 pub op_name: &'static str,
423 pub op_inst: &'a OperatorInstance,
425 pub arguments: &'a Punctuated<Expr, Token![,]>,
431 pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
433}
434impl WriteContextArgs<'_> {
435 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
441 Ident::new(
442 &format!(
443 "sg_{:?}_node_{:?}_{}",
444 self.subgraph_id.data(),
445 self.node_id.data(),
446 suffix.as_ref(),
447 ),
448 self.op_span,
449 )
450 }
451
452 pub fn persistence_args_disallow_mutable<const N: usize>(
454 &self,
455 diagnostics: &mut Diagnostics,
456 ) -> [Persistence; N] {
457 let len = self.op_inst.generics.persistence_args.len();
458 if 0 != len && 1 != len && N != len {
459 diagnostics.push(Diagnostic::spanned(
460 self.op_span,
461 Level::Error,
462 format!(
463 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
464 self.op_name, N
465 ),
466 ));
467 }
468
469 let default_persistence = if self.loop_id.is_some() {
470 Persistence::None
471 } else {
472 Persistence::Tick
473 };
474 let mut out = [default_persistence; N];
475 self.op_inst
476 .generics
477 .persistence_args
478 .iter()
479 .copied()
480 .cycle() .take(N)
482 .enumerate()
483 .filter(|&(_i, p)| {
484 if p == Persistence::Mutable {
485 diagnostics.push(Diagnostic::spanned(
486 self.op_span,
487 Level::Error,
488 format!(
489 "An implementation of `'{}` does not exist",
490 p.to_str_lowercase()
491 ),
492 ));
493 false
494 } else {
495 true
496 }
497 })
498 .for_each(|(i, p)| {
499 out[i] = p;
500 });
501 out
502 }
503}
504
/// Object-safe view of a range.
///
/// Used as `&'static dyn RangeTrait<usize>` in [`OperatorConstraints`] to describe how many
/// inputs, outputs or arguments an operator accepts.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Returns `true` if `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Renders the range as an English phrase for messages.
    ///
    /// Examples: `"any number of"`, `"exactly 1"`, `"at least 1 and less than 3"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        use std::ops::Bound::{Excluded, Included, Unbounded};

        let bounds = (self.start_bound(), self.end_bound());
        match bounds {
            (Unbounded, Unbounded) => "any number of".to_owned(),
            (Unbounded, Included(x)) => format!("at most {}", x),
            (Unbounded, Excluded(x)) => format!("less than {}", x),
            // The guarded arm must stay ahead of the general `(Included, Included)` arm.
            (Included(n), Included(x)) if n == x => format!("exactly {}", n),
            (Included(n), Included(x)) => format!("at least {} and at most {}", n, x),
            (Included(n), Excluded(x)) => format!("at least {} and less than {}", n, x),
            (Included(n), Unbounded) => format!("at least {}", n),
            (Excluded(n), Included(x)) => format!("more than {} and at most {}", n, x),
            (Excluded(n), Excluded(x)) => format!("more than {} and less than {}", n, x),
            (Excluded(n), Unbounded) => format!("more than {}", n),
        }
    }
}

/// Every `Send + Sync + Debug` [`RangeBounds`] (e.g. `0..`, `1..=1`) is a [`RangeTrait`].
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    // The forwarding calls are fully qualified so it is obvious they go to `RangeBounds`
    // rather than recursing into `RangeTrait`.
    fn start_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
569
570#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
572pub enum Persistence {
573 None,
575 Loop,
577 Tick,
579 Static,
581 Mutable,
583}
584impl Persistence {
585 pub fn to_str_lowercase(self) -> &'static str {
587 match self {
588 Persistence::None => "none",
589 Persistence::Tick => "tick",
590 Persistence::Loop => "loop",
591 Persistence::Static => "static",
592 Persistence::Mutable => "mutable",
593 }
594 }
595}
596
597fn make_missing_runtime_msg(op_name: &str) -> Literal {
599 Literal::string(&format!(
600 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
601 op_name
602 ))
603}
604
605#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
607pub enum OperatorCategory {
608 Map,
610 Filter,
612 Flatten,
614 Fold,
616 KeyedFold,
618 LatticeFold,
620 Persistence,
622 MultiIn,
624 MultiOut,
626 Source,
628 Sink,
630 Control,
632 CompilerFusionOperator,
634 Windowing,
636 Unwindowing,
638}
639impl OperatorCategory {
640 pub fn name(self) -> &'static str {
642 self.get_variant_docs().split_once(":").unwrap().0
643 }
644 pub fn description(self) -> &'static str {
646 self.get_variant_docs().split_once(":").unwrap().1
647 }
648}
649
/// Special flow role an operator can play (`OperatorConstraints::flo_type`).
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
/// NOTE(review): semantics are inferred from the variant names only — confirm.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    /// Presumably: an operator that introduces data into the flow.
    Source,
    /// Presumably: an operator that windows `loop` inputs.
    Windowing,
    /// Presumably: an operator that collects `loop` outputs.
    Unwindowing,
    /// Presumably: an operator that feeds data into the next loop iteration.
    NextIteration,
}