infotheory/backends/llm_policy.rs

1use anyhow::{Context, Result, bail};
2use std::collections::BTreeSet;
3use std::path::PathBuf;
4
/// Position expression used in policy schedules.
///
/// A position is either an absolute offset or a percentage of the stream length;
/// percentages can only be resolved once the total length is known.
#[derive(Clone, Debug, PartialEq)]
pub enum PositionExpr {
    /// Absolute byte/token offset.
    Bytes(u64),
    /// Percentage (`0..=100`) of the known total stream length.
    Percent(f64),
}
13
/// Optimizer family selected by a train action.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OptimizerKind {
    /// Plain stochastic gradient descent.
    Sgd,
    /// The Adam optimizer.
    Adam,
}
22
/// Hyper-parameters attached to a training action.
#[derive(Clone, Debug, PartialEq)]
pub struct OptimizerHyperParams {
    /// Learning rate (the parser clamps it to be non-negative).
    pub lr: f32,
    /// Apply one update every `stride` eligible steps.
    pub stride: usize,
    /// Truncated backprop length for full-parameter training.
    pub bptt: usize,
    /// Gradient clipping threshold; `0` disables clipping.
    pub clip: f32,
    /// Momentum used by SGD-style updates.
    pub momentum: f32,
}

impl Default for OptimizerHyperParams {
    /// `lr = 0.001`, update every step, `bptt = 1`, no clipping, `momentum = 0.9`.
    fn default() -> Self {
        let (stride, bptt) = (1, 1);
        OptimizerHyperParams {
            lr: 1e-3,
            stride,
            bptt,
            clip: 0.0,
            momentum: 0.9,
        }
    }
}
49
50#[derive(Clone, Debug, PartialEq, Eq)]
51/// Canonical set of train scopes (`all`, `none`, or sorted named scopes).
52pub struct TrainScopeSet {
53    /// If true, every allowed scope is enabled.
54    pub all: bool,
55    /// Sorted explicit scope names used when `all == false`.
56    pub names: Vec<String>,
57}
58
59impl TrainScopeSet {
60    /// Enable all scopes.
61    pub fn all() -> Self {
62        Self {
63            all: true,
64            names: Vec::new(),
65        }
66    }
67
68    /// Disable all scopes.
69    pub fn none() -> Self {
70        Self {
71            all: false,
72            names: Vec::new(),
73        }
74    }
75
76    /// Returns whether `name` is enabled by this set.
77    pub fn contains(&self, name: &str) -> bool {
78        self.all
79            || self
80                .names
81                .binary_search_by(|s| s.as_str().cmp(name))
82                .is_ok()
83    }
84
85    /// Returns `true` when the set explicitly enables no scopes.
86    pub fn is_none(&self) -> bool {
87        !self.all && self.names.is_empty()
88    }
89
90    /// Parse a scope expression against an allow-list.
91    ///
92    /// Supports `all`, `none`, or a `+`/`|`/`/` separated list.
93    pub fn parse(value: &str, allowed_scopes: &[&str]) -> Result<Self> {
94        let v = value.trim().to_ascii_lowercase();
95        if v.is_empty() {
96            bail!("empty train scope");
97        }
98        if v == "all" {
99            return Ok(Self::all());
100        }
101        if v == "none" {
102            return Ok(Self::none());
103        }
104
105        let mut out = BTreeSet::<String>::new();
106        for tok in v.split(['+', '|', '/']) {
107            let t = tok.trim();
108            if t.is_empty() {
109                continue;
110            }
111            if !allowed_scopes.contains(&t) {
112                bail!(
113                    "unknown train scope '{t}', allowed: {}",
114                    allowed_scopes.join(",")
115                );
116            }
117            if t == "all" || t == "none" {
118                bail!("scope list cannot mix '{t}' with named scopes");
119            }
120            out.insert(t.to_string());
121        }
122
123        Ok(Self {
124            all: false,
125            names: out.into_iter().collect(),
126        })
127    }
128
129    /// Canonical string form (`all`, `none`, or `a+b+c`).
130    pub fn canonical(&self) -> String {
131        if self.all {
132            return "all".to_string();
133        }
134        if self.names.is_empty() {
135            return "none".to_string();
136        }
137        self.names.join("+")
138    }
139}
140
141#[derive(Clone, Debug, PartialEq)]
142/// Concrete training directive.
143pub struct TrainAction {
144    /// Parameter subsets to update.
145    pub scope: TrainScopeSet,
146    /// Optimizer family.
147    pub optimizer: OptimizerKind,
148    /// Optimizer hyper-parameters.
149    pub hyper: OptimizerHyperParams,
150}
151
152#[derive(Clone, Debug, PartialEq)]
153/// Action emitted by the policy at each stream position.
154pub enum PolicyAction {
155    /// Inference-only step (no adaptation).
156    Infer,
157    /// Training step with provided parameters.
158    Train(TrainAction),
159}
160
161#[derive(Clone, Debug, PartialEq)]
162/// Interval schedule rule (`start..end:action`).
163pub struct PolicyRule {
164    /// Inclusive interval start.
165    pub start: PositionExpr,
166    /// Exclusive interval end.
167    pub end: PositionExpr,
168    /// Action applied in the interval.
169    pub action: PolicyAction,
170}
171
172#[derive(Clone, Debug, PartialEq)]
173/// One segment inside a repeating pattern.
174pub struct RepeatSegment {
175    /// Segment duration.
176    pub span: PositionExpr,
177    /// Action active for this segment.
178    pub action: PolicyAction,
179}
180
181#[derive(Clone, Debug, PartialEq)]
182/// Repeat schedule rule with cycle period and inner pattern.
183pub struct RepeatRule {
184    /// Inclusive active-start boundary.
185    pub start: PositionExpr,
186    /// Exclusive active-end boundary.
187    pub end: PositionExpr,
188    /// Repeat cycle length.
189    pub period: PositionExpr,
190    /// Ordered segments inside one cycle.
191    pub pattern: Vec<RepeatSegment>,
192}
193
194#[derive(Clone, Debug, PartialEq)]
195/// Top-level schedule rule.
196pub enum ScheduleRule {
197    /// Single contiguous interval.
198    Interval(PolicyRule),
199    /// Periodic repeating schedule.
200    Repeat(RepeatRule),
201}
202
203#[derive(Clone, Debug, PartialEq)]
204/// Parsed policy specification for online LLM adaptation.
205pub struct LlmPolicy {
206    /// Optional initial weights/source to load before running schedule.
207    pub load_from: Option<PathBuf>,
208    /// Ordered schedule rules.
209    pub schedule: Vec<ScheduleRule>,
210}
211
212/// Returns `true` if any schedule branch can emit a training action.
213pub fn policy_can_train(policy: &LlmPolicy) -> bool {
214    for rule in &policy.schedule {
215        match rule {
216            ScheduleRule::Interval(interval) => {
217                if matches!(interval.action, PolicyAction::Train(_)) {
218                    return true;
219                }
220            }
221            ScheduleRule::Repeat(repeat) => {
222                if repeat
223                    .pattern
224                    .iter()
225                    .any(|seg| matches!(seg.action, PolicyAction::Train(_)))
226                {
227                    return true;
228                }
229            }
230        }
231    }
232    false
233}
234
235#[derive(Clone, Debug)]
236struct CompiledPatternSegment {
237    end: u64,
238    action: PolicyAction,
239}
240
241#[derive(Clone, Debug)]
242enum CompiledScheduleRule {
243    Interval {
244        start: u64,
245        end: u64,
246        action: PolicyAction,
247    },
248    Repeat {
249        start: u64,
250        end: u64,
251        period: u64,
252        pattern_total: u64,
253        pattern: Vec<CompiledPatternSegment>,
254    },
255}
256
257#[derive(Clone, Debug)]
258/// Compiled, position-resolved policy ready for runtime evaluation.
259pub struct CompiledPolicy {
260    rules: Vec<CompiledScheduleRule>,
261}
262
263impl CompiledPolicy {
264    /// Resolve the action at absolute position `pos`.
265    pub fn action_at(&self, pos: u64) -> PolicyAction {
266        for rule in &self.rules {
267            match rule {
268                CompiledScheduleRule::Interval { start, end, action }
269                    if pos >= *start && pos < *end =>
270                {
271                    return action.clone();
272                }
273                CompiledScheduleRule::Repeat {
274                    start,
275                    end,
276                    period,
277                    pattern_total,
278                    pattern,
279                } if pos >= *start && pos < *end => {
280                    let phase = (pos - *start) % *period;
281                    if phase >= *pattern_total {
282                        return PolicyAction::Infer;
283                    }
284                    for seg in pattern {
285                        if phase < seg.end {
286                            return seg.action.clone();
287                        }
288                    }
289                    return PolicyAction::Infer;
290                }
291                _ => {}
292            }
293        }
294        PolicyAction::Infer
295    }
296}
297
298#[derive(Clone, Debug)]
299/// Stateful cursor over a [`CompiledPolicy`].
300pub struct PolicyRuntime {
301    compiled: CompiledPolicy,
302    cursor: u64,
303}
304
305impl PolicyRuntime {
306    /// Create a runtime positioned at cursor `0`.
307    pub fn new(compiled: CompiledPolicy) -> Self {
308        Self {
309            compiled,
310            cursor: 0,
311        }
312    }
313
314    #[inline]
315    /// Current cursor position.
316    pub fn cursor(&self) -> u64 {
317        self.cursor
318    }
319
320    #[inline]
321    /// Set cursor position directly.
322    pub fn set_cursor(&mut self, cursor: u64) {
323        self.cursor = cursor;
324    }
325
326    #[inline]
327    /// Peek current action without advancing cursor.
328    pub fn peek_action(&self) -> PolicyAction {
329        self.compiled.action_at(self.cursor)
330    }
331
332    #[inline]
333    /// Return current action and advance cursor by one.
334    pub fn next_action(&mut self) -> PolicyAction {
335        let action = self.compiled.action_at(self.cursor);
336        self.cursor = self.cursor.saturating_add(1);
337        action
338    }
339}
340
341/// Split `method` into base method segment and optional `policy:...` segment.
342pub fn split_method_policy_segments(method: &str) -> Result<(String, Option<String>)> {
343    let trimmed = method.trim();
344    if trimmed.is_empty() {
345        bail!("empty method string");
346    }
347    let mut iter = trimmed.split(';');
348    let base = iter.next().unwrap_or_default().trim().to_string();
349    if base.is_empty() {
350        bail!("method is missing cfg/file segment");
351    }
352
353    let mut policy = None::<String>;
354    for seg in iter {
355        let s = seg.trim();
356        if s.is_empty() {
357            continue;
358        }
359        if let Some(rest) = s.strip_prefix("policy:") {
360            if policy.is_some() {
361                bail!("duplicate policy segment in method string");
362            }
363            policy = Some(rest.trim().to_string());
364            continue;
365        }
366        bail!("unknown method segment '{s}', expected 'policy:...'");
367    }
368
369    Ok((base, policy))
370}
371
372/// Parse a `policy:...` string into [`LlmPolicy`].
373pub fn parse_policy_segment(policy_segment: &str, allowed_scopes: &[&str]) -> Result<LlmPolicy> {
374    let body = policy_segment
375        .trim()
376        .strip_prefix("policy:")
377        .unwrap_or(policy_segment.trim())
378        .trim();
379    if body.is_empty() {
380        bail!("empty policy segment");
381    }
382
383    let mut load_from = None::<PathBuf>;
384    let mut schedule_raw = None::<String>;
385
386    for entry in split_top_level(body, ',')? {
387        let entry = entry.trim();
388        if entry.is_empty() {
389            continue;
390        }
391        let (k, v) = entry
392            .split_once('=')
393            .with_context(|| format!("invalid policy key/value pair '{entry}'"))?;
394        let key = k.trim().to_ascii_lowercase();
395        let val = v.trim();
396        match key.as_str() {
397            "load_from" => {
398                if val.is_empty() {
399                    bail!("policy load_from must not be empty");
400                }
401                load_from = Some(PathBuf::from(val));
402            }
403            "schedule" => {
404                if val.is_empty() {
405                    bail!("policy schedule must not be empty");
406                }
407                schedule_raw = Some(val.to_string());
408            }
409            other => bail!("unknown policy key '{other}'"),
410        }
411    }
412
413    let schedule_raw =
414        schedule_raw.ok_or_else(|| anyhow::anyhow!("policy requires schedule=..."))?;
415    let mut schedule = Vec::<ScheduleRule>::new();
416    for token in split_top_level(&schedule_raw, '|')? {
417        let t = token.trim();
418        if t.is_empty() {
419            continue;
420        }
421        if t.starts_with("repeat(") {
422            schedule.push(ScheduleRule::Repeat(parse_repeat_rule(t, allowed_scopes)?));
423        } else {
424            schedule.push(ScheduleRule::Interval(parse_interval_rule(
425                t,
426                allowed_scopes,
427            )?));
428        }
429    }
430
431    if schedule.is_empty() {
432        bail!("policy schedule must contain at least one rule");
433    }
434
435    Ok(LlmPolicy {
436        load_from,
437        schedule,
438    })
439}
440
441impl LlmPolicy {
442    /// Compile policy boundaries/spans using optional known total symbol count.
443    pub fn compile(&self, total_symbols: Option<u64>) -> Result<CompiledPolicy> {
444        let mut out = Vec::<CompiledScheduleRule>::with_capacity(self.schedule.len());
445        for rule in &self.schedule {
446            match rule {
447                ScheduleRule::Interval(r) => {
448                    let start = resolve_boundary(&r.start, total_symbols)?;
449                    let end = resolve_boundary(&r.end, total_symbols)?;
450                    if end <= start {
451                        bail!("invalid interval with end <= start ({start}..{end})");
452                    }
453                    out.push(CompiledScheduleRule::Interval {
454                        start,
455                        end,
456                        action: r.action.clone(),
457                    });
458                }
459                ScheduleRule::Repeat(r) => {
460                    let start = resolve_boundary(&r.start, total_symbols)?;
461                    let end = resolve_boundary(&r.end, total_symbols)?;
462                    if end <= start {
463                        bail!("invalid repeat interval with end <= start ({start}..{end})");
464                    }
465                    let period = resolve_span(&r.period, total_symbols)?;
466                    if period == 0 {
467                        bail!("repeat period must be > 0");
468                    }
469
470                    let mut pattern = Vec::<CompiledPatternSegment>::with_capacity(r.pattern.len());
471                    let mut acc = 0u64;
472                    for seg in &r.pattern {
473                        let span = resolve_span(&seg.span, total_symbols)?;
474                        if span == 0 {
475                            bail!("repeat pattern segment span must be > 0");
476                        }
477                        acc = acc.saturating_add(span);
478                        pattern.push(CompiledPatternSegment {
479                            end: acc,
480                            action: seg.action.clone(),
481                        });
482                    }
483                    if pattern.is_empty() {
484                        bail!("repeat pattern must contain at least one segment");
485                    }
486
487                    out.push(CompiledScheduleRule::Repeat {
488                        start,
489                        end,
490                        period,
491                        pattern_total: acc,
492                        pattern,
493                    });
494                }
495            }
496        }
497        Ok(CompiledPolicy { rules: out })
498    }
499
500    /// Serialize back to canonical `load_from=...,schedule=...` form.
501    pub fn canonical(&self) -> String {
502        let mut out = String::new();
503        if let Some(path) = &self.load_from {
504            out.push_str("load_from=");
505            out.push_str(&path.display().to_string());
506            out.push(',');
507        }
508        out.push_str("schedule=");
509        for (idx, r) in self.schedule.iter().enumerate() {
510            if idx > 0 {
511                out.push('|');
512            }
513            match r {
514                ScheduleRule::Interval(i) => {
515                    out.push_str(&position_to_string(&i.start));
516                    out.push_str("..");
517                    out.push_str(&position_to_string(&i.end));
518                    out.push(':');
519                    out.push_str(&action_to_string(&i.action));
520                }
521                ScheduleRule::Repeat(rep) => {
522                    out.push_str("repeat(");
523                    out.push_str(&position_to_string(&rep.start));
524                    out.push_str("..");
525                    out.push_str(&position_to_string(&rep.end));
526                    out.push_str(",period=");
527                    out.push_str(&position_to_string(&rep.period));
528                    out.push_str(",pattern=");
529                    for (j, seg) in rep.pattern.iter().enumerate() {
530                        if j > 0 {
531                            out.push('+');
532                        }
533                        out.push_str(&position_to_string(&seg.span));
534                        out.push(':');
535                        out.push_str(&action_to_string(&seg.action));
536                    }
537                    out.push(')');
538                }
539            }
540        }
541        out
542    }
543}
544
545fn parse_interval_rule(token: &str, allowed_scopes: &[&str]) -> Result<PolicyRule> {
546    let (range, action_s) = token.split_once(':').with_context(|| {
547        format!("invalid schedule rule '{token}', expected <start>..<end>:<action>")
548    })?;
549    let (start, end) = parse_range(range)?;
550    let action = parse_action(action_s, allowed_scopes)?;
551    Ok(PolicyRule { start, end, action })
552}
553
554fn parse_repeat_rule(token: &str, allowed_scopes: &[&str]) -> Result<RepeatRule> {
555    let inner = token
556        .strip_prefix("repeat(")
557        .and_then(|s| s.strip_suffix(')'))
558        .ok_or_else(|| {
559            anyhow::anyhow!(
560                "invalid repeat rule '{token}', expected repeat(<start>..<end>,period=...,pattern=...)"
561            )
562        })?;
563
564    let args = split_top_level(inner, ',')?;
565    if args.is_empty() {
566        bail!("repeat rule is empty");
567    }
568
569    let (start, end) = parse_range(args[0].trim())?;
570    let mut period = None::<PositionExpr>;
571    let mut pattern = None::<Vec<RepeatSegment>>;
572
573    for arg in args.into_iter().skip(1) {
574        let arg = arg.trim();
575        if arg.is_empty() {
576            continue;
577        }
578        let (k, v) = arg
579            .split_once('=')
580            .with_context(|| format!("invalid repeat argument '{arg}'"))?;
581        let key = k.trim().to_ascii_lowercase();
582        let val = v.trim();
583        match key.as_str() {
584            "period" => period = Some(parse_position_expr(val)?),
585            "pattern" => {
586                let mut segs = Vec::<RepeatSegment>::new();
587                for seg in split_top_level(val, '+')? {
588                    let s = seg.trim();
589                    if s.is_empty() {
590                        continue;
591                    }
592                    let (span_s, action_s) = s
593                        .split_once(':')
594                        .with_context(|| format!("invalid repeat pattern segment '{s}'"))?;
595                    segs.push(RepeatSegment {
596                        span: parse_position_expr(span_s.trim())?,
597                        action: parse_action(action_s.trim(), allowed_scopes)?,
598                    });
599                }
600                pattern = Some(segs);
601            }
602            other => bail!("unknown repeat key '{other}'"),
603        }
604    }
605
606    let period = period.ok_or_else(|| anyhow::anyhow!("repeat rule requires period=..."))?;
607    let pattern = pattern.ok_or_else(|| anyhow::anyhow!("repeat rule requires pattern=..."))?;
608    if pattern.is_empty() {
609        bail!("repeat pattern must not be empty");
610    }
611
612    Ok(RepeatRule {
613        start,
614        end,
615        period,
616        pattern,
617    })
618}
619
620fn parse_action(token: &str, allowed_scopes: &[&str]) -> Result<PolicyAction> {
621    let t = token.trim();
622    if t.eq_ignore_ascii_case("infer") {
623        return Ok(PolicyAction::Infer);
624    }
625
626    if t.eq_ignore_ascii_case("train") {
627        return Ok(PolicyAction::Train(TrainAction {
628            scope: TrainScopeSet::all(),
629            optimizer: OptimizerKind::Sgd,
630            hyper: OptimizerHyperParams::default(),
631        }));
632    }
633
634    let inner = t
635        .strip_prefix("train(")
636        .and_then(|s| s.strip_suffix(')'))
637        .ok_or_else(|| anyhow::anyhow!("invalid action '{token}', expected infer or train(...)"))?;
638
639    let mut scope = TrainScopeSet::all();
640    let mut optimizer = OptimizerKind::Sgd;
641    let mut hyper = OptimizerHyperParams::default();
642
643    for arg in split_top_level(inner, ',')? {
644        let arg = arg.trim();
645        if arg.is_empty() {
646            continue;
647        }
648        let (k, v) = arg
649            .split_once('=')
650            .with_context(|| format!("invalid train argument '{arg}'"))?;
651        let key = k.trim().to_ascii_lowercase();
652        let val = v.trim();
653        match key.as_str() {
654            "scope" => scope = TrainScopeSet::parse(val, allowed_scopes)?,
655            "opt" | "optimizer" => {
656                optimizer = match val.to_ascii_lowercase().as_str() {
657                    "sgd" => OptimizerKind::Sgd,
658                    "adam" => OptimizerKind::Adam,
659                    other => bail!("unknown optimizer '{other}'"),
660                };
661            }
662            "lr" => {
663                hyper.lr = val
664                    .parse::<f32>()
665                    .with_context(|| format!("invalid lr '{val}'"))?
666                    .max(0.0)
667            }
668            "stride" => {
669                hyper.stride = val
670                    .parse::<usize>()
671                    .with_context(|| format!("invalid stride '{val}'"))?
672                    .max(1)
673            }
674            "bptt" => {
675                hyper.bptt = val
676                    .parse::<usize>()
677                    .with_context(|| format!("invalid bptt '{val}'"))?
678                    .max(1)
679            }
680            "clip" => {
681                hyper.clip = val
682                    .parse::<f32>()
683                    .with_context(|| format!("invalid clip '{val}'"))?
684                    .max(0.0)
685            }
686            "momentum" => {
687                hyper.momentum = val
688                    .parse::<f32>()
689                    .with_context(|| format!("invalid momentum '{val}'"))?
690            }
691            other => bail!("unknown train argument key '{other}'"),
692        }
693    }
694
695    Ok(PolicyAction::Train(TrainAction {
696        scope,
697        optimizer,
698        hyper,
699    }))
700}
701
702fn parse_range(range: &str) -> Result<(PositionExpr, PositionExpr)> {
703    let (start, end) = range
704        .split_once("..")
705        .with_context(|| format!("invalid range '{range}', expected <start>..<end>"))?;
706    Ok((parse_position_expr(start)?, parse_position_expr(end)?))
707}
708
709fn parse_position_expr(token: &str) -> Result<PositionExpr> {
710    let t = token.trim();
711    if let Some(pct_s) = t.strip_suffix('%') {
712        let pct = pct_s
713            .trim()
714            .parse::<f64>()
715            .with_context(|| format!("invalid percent position '{t}'"))?;
716        if !(0.0..=100.0).contains(&pct) {
717            bail!("percent position must be in [0,100], got {pct}");
718        }
719        return Ok(PositionExpr::Percent(pct));
720    }
721    let abs = t
722        .parse::<u64>()
723        .with_context(|| format!("invalid absolute position '{t}'"))?;
724    Ok(PositionExpr::Bytes(abs))
725}
726
727fn resolve_boundary(expr: &PositionExpr, total_symbols: Option<u64>) -> Result<u64> {
728    match expr {
729        PositionExpr::Bytes(v) => Ok(match total_symbols {
730            Some(total) => (*v).min(total),
731            None => *v,
732        }),
733        PositionExpr::Percent(pct) => {
734            let total = total_symbols.ok_or_else(|| {
735                anyhow::anyhow!(
736                    "percent policy boundary requires known total symbol count at runtime"
737                )
738            })?;
739            let resolved = ((total as f64) * (*pct / 100.0)).floor() as u64;
740            Ok(resolved.min(total))
741        }
742    }
743}
744
745fn resolve_span(expr: &PositionExpr, total_symbols: Option<u64>) -> Result<u64> {
746    match expr {
747        PositionExpr::Bytes(v) => Ok(*v),
748        PositionExpr::Percent(pct) => {
749            let total = total_symbols.ok_or_else(|| {
750                anyhow::anyhow!("percent policy span requires known total symbol count at runtime")
751            })?;
752            Ok(((total as f64) * (*pct / 100.0)).floor() as u64)
753        }
754    }
755}
756
757fn split_top_level(input: &str, delim: char) -> Result<Vec<&str>> {
758    let mut parts = Vec::new();
759    let mut depth = 0i32;
760    let mut start = 0usize;
761    for (idx, ch) in input.char_indices() {
762        match ch {
763            '(' => depth += 1,
764            ')' => {
765                depth -= 1;
766                if depth < 0 {
767                    bail!("unbalanced ')' in '{input}'");
768                }
769            }
770            _ if ch == delim && depth == 0 => {
771                parts.push(&input[start..idx]);
772                start = idx + ch.len_utf8();
773            }
774            _ => {}
775        }
776    }
777    if depth != 0 {
778        bail!("unbalanced '(' in '{input}'");
779    }
780    parts.push(&input[start..]);
781    Ok(parts)
782}
783
784fn position_to_string(expr: &PositionExpr) -> String {
785    match expr {
786        PositionExpr::Bytes(v) => v.to_string(),
787        PositionExpr::Percent(p) => {
788            if p.fract() == 0.0 {
789                format!("{}%", *p as i64)
790            } else {
791                format!("{p}%")
792            }
793        }
794    }
795}
796
797fn action_to_string(action: &PolicyAction) -> String {
798    match action {
799        PolicyAction::Infer => "infer".to_string(),
800        PolicyAction::Train(train) => {
801            let opt = match train.optimizer {
802                OptimizerKind::Sgd => "sgd",
803                OptimizerKind::Adam => "adam",
804            };
805            format!(
806                "train(scope={},opt={},lr={},stride={},bptt={},clip={},momentum={})",
807                train.scope.canonical(),
808                opt,
809                train.hyper.lr,
810                train.hyper.stride,
811                train.hyper.bptt,
812                train.hyper.clip,
813                train.hyper.momentum,
814            )
815        }
816    }
817}
818
#[cfg(test)]
mod tests {
    use super::*;

    /// Scope allow-list for an RWKV-style parameter layout, including the `all`/`none`
    /// keywords.
    const RWKV_SCOPES: &[&str] = &[
        "embed",
        "pre_norm",
        "attn_norm",
        "ffn_norm",
        "attn",
        "ffn",
        "head",
        "bias",
        "all",
        "none",
    ];

    #[test]
    fn parse_policy_basic_and_compile() {
        let spec = "policy:schedule=0..10:infer|10..100%:train(scope=head+bias,opt=adam,lr=0.01,stride=2,bptt=4,clip=1.0,momentum=0.95)";
        let policy = parse_policy_segment(spec, RWKV_SCOPES).expect("policy");
        let compiled = policy.compile(Some(100)).expect("compile");

        assert!(matches!(compiled.action_at(0), PolicyAction::Infer));

        if let PolicyAction::Train(train) = compiled.action_at(15) {
            assert!(train.scope.contains("head"));
            assert!(train.scope.contains("bias"));
            assert_eq!(train.hyper.stride, 2);
            assert_eq!(train.hyper.bptt, 4);
        } else {
            panic!("expected train");
        }
    }

    #[test]
    fn parse_repeat_policy() {
        // `repeat(...)` is a top-level schedule rule, not an interval action.
        let nested = "schedule=0..100:repeat(0..100,period=10,pattern=3:train(scope=head,opt=sgd,lr=0.1,stride=1,bptt=1,clip=0,momentum=0.9)+7:infer)";
        assert!(parse_policy_segment(nested, RWKV_SCOPES).is_err());

        let top_level = "schedule=repeat(0..100,period=10,pattern=3:train(scope=head,opt=sgd,lr=0.1,stride=1,bptt=1,clip=0,momentum=0.9)+7:infer)";
        let compiled = parse_policy_segment(top_level, RWKV_SCOPES)
            .expect("repeat policy")
            .compile(Some(100))
            .expect("compile");

        for (pos, expect_train) in [(0u64, true), (3, false), (10, true)] {
            let trains = matches!(compiled.action_at(pos), PolicyAction::Train(_));
            assert_eq!(trains, expect_train, "unexpected action at position {pos}");
        }
    }

    #[test]
    fn split_method_policy() {
        let (base, policy) =
            split_method_policy_segments("cfg:hidden=64;policy:schedule=0..100:infer")
                .expect("split");
        assert_eq!(base, "cfg:hidden=64");
        assert_eq!(policy.as_deref(), Some("schedule=0..100:infer"));
    }
}