/*
 * Decompiled with CFR 0.152.
 */
package mb.statix.spec;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.MoreCollectors;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.usethesource.capsule.Set;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import mb.nabl2.terms.ITerm;
import mb.nabl2.terms.ITermVar;
import mb.nabl2.terms.build.TermBuild;
import mb.nabl2.terms.matching.Pattern;
import mb.nabl2.terms.matching.TermPattern;
import mb.nabl2.terms.substitution.IRenaming;
import mb.nabl2.terms.unification.OccursException;
import mb.nabl2.terms.unification.u.IUnifier;
import mb.nabl2.terms.unification.u.PersistentUnifier;
import mb.nabl2.terms.unification.ud.Diseq;
import mb.nabl2.terms.unification.ud.IUniDisunifier;
import mb.nabl2.terms.unification.ud.PersistentUniDisunifier;
import mb.nabl2.util.Tuple2;
import mb.nabl2.util.Tuple3;
import mb.statix.constraints.CConj;
import mb.statix.constraints.CEqual;
import mb.statix.constraints.CUser;
import mb.statix.constraints.Constraints;
import mb.statix.solver.IConstraint;
import mb.statix.solver.IState;
import mb.statix.solver.StateUtil;
import mb.statix.spec.ApplyResult;
import mb.statix.spec.FreshVars;
import mb.statix.spec.Rule;
import mb.statix.spec.RuleSet;
import org.metaborg.util.functions.Action1;
import org.metaborg.util.functions.Function1;
import org.metaborg.util.functions.PartialFunction1;
import org.metaborg.util.functions.Predicate1;
import org.metaborg.util.functions.Predicate2;

/**
 * Static helpers for working with Statix {@link Rule}s: applying rules to call arguments,
 * computing order-independent variants of a rule list, inlining and simplifying rules, and
 * generating rule "fragments" by repeated expansion of user-constraint calls.
 *
 * <p>NOTE(review): this class is CFR decompiler output (see the file header). Raw types, explicit
 * casts and the synthetic {@code lambda$37}/{@code lambda$39} methods are decompilation artifacts
 * rather than hand-written code; consult the original source before restructuring anything here.
 * The file also imports both {@code io.usethesource.capsule.Set} and {@code java.util.Set}, which
 * javac rejects as a duplicate single-type import; the unqualified {@code Set.Transient} /
 * {@code Set.Immutable} uses below are meant to be the capsule type.
 */
public class RuleUtil {
    /**
     * Applies {@code rules} in order to {@code args}, requiring that at most one rule applies.
     *
     * <p>Delegates to {@code applyOrdered} with {@code onlyOne == true}; that call yields at most
     * one pair, so collecting with {@code MoreCollectors.toOptional()} cannot see two elements.
     *
     * @param state the solver state whose unifier is used for matching
     * @param rules the candidate rules, tried in list order
     * @param args the call arguments matched against each rule's parameter patterns
     * @param cause the constraint attached as cause to the instantiated rule body; may be null
     * @return {@code Optional.empty()} if a second applicable rule was found; otherwise an
     *         {@code Optional} holding either the single applicable (rule, result) pair or, if
     *         no rule applies, an empty inner {@code Optional}
     */
    public static Optional<Optional<Tuple2<Rule, ApplyResult>>> applyOrderedOne(IState.Immutable state, List<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause) {
        return RuleUtil.applyOrdered(state, rules, args, cause, true).map(rs -> (Optional)rs.stream().collect(MoreCollectors.toOptional()));
    }

    /**
     * Applies {@code rules} in order to {@code args} and returns every applicable rule up to and
     * including the first one whose result carries no guard.
     *
     * <p>Each later rule is tried in a state extended with the guards of the earlier applicable
     * rules (see {@code applyOrdered}). With {@code onlyOne == false}, {@code applyOrdered} never
     * returns an empty {@code Optional}, so the {@code get()} below is safe.
     *
     * @param state the solver state whose unifier is used for matching
     * @param rules the candidate rules, tried in list order
     * @param args the call arguments
     * @param cause the cause attached to instantiated rule bodies; may be null
     * @return the applicable (rule, result) pairs, in rule order
     */
    public static List<Tuple2<Rule, ApplyResult>> applyOrderedAll(IState.Immutable state, List<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause) {
        return RuleUtil.applyOrdered(state, rules, args, cause, false).get();
    }

    /**
     * Shared implementation of {@link #applyOrderedOne} and {@link #applyOrderedAll}.
     *
     * <p>Rules are tried in list order and rules that do not apply are skipped. For an applicable
     * rule, the (rule, result) pair is recorded. If the result has no guard, iteration stops.
     * Otherwise the guard is added to the state's unifier with {@code disunify}, and the remaining
     * rules are tried under that extra disequality.
     *
     * @param onlyOne if true, return {@code Optional.empty()} as soon as a second applicable rule
     *        is found
     * @return the recorded pairs; {@code Optional.empty()} only when {@code onlyOne} is set and a
     *         second rule applies
     * @throws IllegalStateException if a guard cannot be added to the unifier as a disequality
     */
    private static Optional<List<Tuple2<Rule, ApplyResult>>> applyOrdered(IState.Immutable state, List<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause, boolean onlyOne) {
        ImmutableList.Builder results = ImmutableList.builder();
        // Used as a plain flag: the loop below is sequential.
        AtomicBoolean foundOne = new AtomicBoolean(false);
        for (Rule rule : rules) {
            ApplyResult applyResult = RuleUtil.apply(state, rule, args, cause).orElse(null);
            if (applyResult == null) continue;
            if (onlyOne && foundOne.getAndSet(true)) {
                return Optional.empty();
            }
            results.add(Tuple2.of(rule, applyResult));
            // Guard as (universal vars, left term, right term); a result without a guard ends the search.
            Tuple3 guard = applyResult.guard().map(Diseq::toTuple).orElse(null);
            if (guard == null) break;
            // Later rules are tried only in a state where this guard is disequal.
            Optional<IUniDisunifier.Immutable> newUnifier = state.unifier().disunify((Iterable)guard._1(), (ITerm)guard._2(), (ITerm)guard._3()).map(IUniDisunifier.Result::unifier);
            if (!newUnifier.isPresent()) {
                throw new IllegalStateException("Unexpected incompatible guard.");
            }
            state = state.withUnifier(newUnifier.get());
        }
        return Optional.of(results.build());
    }

    /**
     * Tries to apply a single rule to the given call arguments.
     *
     * <p>The rule's parameter patterns are matched against {@code args} using the state's unifier.
     * Variables introduced by the match are created in a melted copy of the state and recorded as
     * universal variables.
     * <ul>
     *   <li>If the match constrains no variables other than those universals, the result has an
     *       empty unifier diff and no guard.</li>
     *   <li>Otherwise the match equalities are unified into the state's unifier. The part of the
     *       resulting diff that mentions the constrained variables becomes a {@link Diseq} guard
     *       over the universal variables, and the result state carries the new unifier.</li>
     * </ul>
     *
     * @param state the current solver state; fresh variables go into {@code state.melt()}
     * @param rule the rule to apply
     * @param args the call arguments
     * @param cause the cause attached to the instantiated rule body; may be null
     * @return the application result, or {@code Optional.empty()} if the match fails or its
     *         equalities cannot be unified (including an occurs-check failure)
     * @throws IllegalStateException if a match that constrains variables yields an empty guard
     */
    public static Optional<ApplyResult> apply(IState.Immutable state, Rule rule, List<? extends ITerm> args, @Nullable IConstraint cause) {
        IState.Transient newState = state.melt();
        Set.Transient _universalVars = Set.Transient.of();
        // Fresh variable for a pattern variable (or a wildcard when absent), recorded as universal.
        Function1<Optional<ITermVar>, ITermVar> fresh = v -> {
            ITermVar f = v.map(newState::freshVar).orElseGet(newState::freshWld);
            _universalVars.__insert((Object)f);
            return f;
        };
        return TermPattern.P.matchWithEqs((Iterable<Pattern>)rule.params(), (Iterable<? extends ITerm>)args, state.unifier(), fresh).flatMap(matchResult -> {
            ApplyResult applyResult;
            Set.Immutable universalVars = _universalVars.freeze();
            // Variables constrained by the match that were not freshly created during it.
            Sets.SetView constrainedVars = Sets.difference(matchResult.constrainedVars(), (Set)universalVars);
            IConstraint newConstraint = rule.body().apply(matchResult.substitution()).withCause(cause);
            if (constrainedVars.isEmpty()) {
                // Nothing outside the fresh variables is constrained: no guard is produced.
                applyResult = ApplyResult.of(newState.freeze(), PersistentUnifier.Immutable.of(), Optional.empty(), newConstraint);
            } else {
                IUniDisunifier.Result unifyResult;
                try {
                    unifyResult = state.unifier().unify(matchResult.equalities()).orElse(null);
                    if (unifyResult == null) {
                        return Optional.empty();
                    }
                }
                catch (OccursException e) {
                    return Optional.empty();
                }
                IUniDisunifier.Immutable newUnifier = unifyResult.unifier();
                IUnifier.Immutable diff = (IUnifier.Immutable)unifyResult.result();
                // The portion of the diff mentioning constrained variables forms the guard.
                IUnifier.Immutable guard = diff.retainAll((Iterable<ITermVar>)constrainedVars).unifier();
                if (guard.isEmpty()) {
                    throw new IllegalStateException("Guard not expected to be empty here.");
                }
                Diseq diseq = Diseq.of((Iterable<ITermVar>)universalVars, guard);
                IState.Immutable resultState = newState.freeze().withUnifier(newUnifier);
                applyResult = ApplyResult.of(resultState, diff, Optional.of(diseq), newConstraint);
            }
            return Optional.of(applyResult);
        });
    }

    /**
     * Applies every rule in {@code rules} independently to {@code args}, all in the same state.
     *
     * <p>Unlike {@link #applyOrderedAll}, no guards are carried from one rule to the next.
     *
     * @param state the solver state whose unifier is used for matching
     * @param rules the candidate rules
     * @param args the call arguments
     * @param cause the cause attached to instantiated rule bodies; may be null
     * @return an immutable list of (rule, result) pairs for the rules that apply, in the
     *         collection's stream order
     */
    public static List<Tuple2<Rule, ApplyResult>> applyAll(IState.Immutable state, Collection<Rule> rules, List<? extends ITerm> args, @Nullable IConstraint cause) {
        return (List)rules.stream().flatMap(rule -> Streams.stream(RuleUtil.apply(state, rule, args, cause)).map(result -> Tuple2.of(rule, result))).collect(ImmutableList.toImmutableList());
    }

    /**
     * Rewrites an ordered list of rules so that the kept rules no longer depend on their order.
     *
     * <p>Each rule is processed in list order:
     * <ol>
     *   <li>Wildcards in its parameter patterns are replaced by fresh variables, and the pattern's
     *       equalities (from {@code asTerm}) are unified. If that fails, the rule is dropped.</li>
     *   <li>For every earlier kept rule's parameter pattern, a disequality is added between this
     *       rule's parameters and a freshly renamed copy of that pattern. The copy's variables are
     *       the universals. If {@code disunify} fails for any of them, the rule is dropped.</li>
     *   <li>A kept rule gets the wildcard-free parameters and a body prefixed by the accumulated
     *       inequalities, and its parameter pattern is remembered for the later rules.</li>
     * </ol>
     *
     * @param rules the rules to rewrite, presumably all for one predicate in priority order
     *        (TODO confirm against callers)
     * @return the rewritten rules
     */
    public static ImmutableSet<Rule> computeOrderIndependentRules(List<Rule> rules) {
        // Parameter patterns of the rules kept so far; later rules are made disequal to them.
        ArrayList guards = Lists.newArrayList();
        return (ImmutableSet)rules.stream().flatMap(r -> {
            IUniDisunifier.Transient diseqs = PersistentUniDisunifier.Immutable.of().melt();
            FreshVars fresh = new FreshVars(r.varSet());
            List paramPatterns = (List)r.params().stream().map(p -> p.eliminateWld(() -> fresh.fresh("_"))).collect(ImmutableList.toImmutableList());
            fresh.fix();
            Pattern paramsPattern = TermPattern.P.newTuple(paramPatterns);
            // Wildcards are gone, so Optional::get is safe for the variable conversion.
            Tuple2<ITerm, List<Tuple2<ITermVar, ITerm>>> p_eqs = paramsPattern.asTerm(Optional::get);
            try {
                if (!diseqs.unify((Iterable<? extends Map.Entry<? extends ITerm, ? extends ITerm>>)p_eqs._2()).isPresent()) {
                    return Stream.empty();
                }
            }
            catch (OccursException e) {
                return Stream.empty();
            }
            boolean guardsOk = guards.stream().allMatch(g -> {
                // Rename the earlier pattern apart from this rule's variables.
                IRenaming swap = fresh.fresh(g.getVars());
                Pattern g1 = g.eliminateWld(() -> fresh.fresh("_"));
                Tuple2<ITerm, List<Tuple2<ITermVar, ITerm>>> t_eqs = g1.apply(swap).asTerm(Optional::get);
                // Fold the earlier pattern's equalities into the compared tuples.
                List leftEqs = (List)t_eqs._2().stream().map(Tuple2::_1).collect(ImmutableList.toImmutableList());
                List rightEqs = (List)t_eqs._2().stream().map(Tuple2::_2).collect(ImmutableList.toImmutableList());
                ITerm left = TermBuild.B.newTuple((ITerm)p_eqs._1(), TermBuild.B.newTuple(leftEqs));
                ITerm right = TermBuild.B.newTuple(t_eqs._1(), TermBuild.B.newTuple(rightEqs));
                // Variables created for the earlier pattern become the disequality's universals.
                Set.Immutable<ITermVar> universals = fresh.reset();
                return diseqs.disunify((Iterable<ITermVar>)universals, left, right).isPresent();
            });
            if (!guardsOk) {
                return Stream.empty();
            }
            guards.add(paramsPattern);
            IConstraint body = Constraints.conjoin(StateUtil.asInequalities(diseqs), r.body());
            return Stream.of(r.withParams(paramPatterns).withBody(body));
        }).collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Inlines one call of {@code rule} into the body of {@code into}.
     *
     * <p>The {@code ith} (zero-based, in the traversal order of {@code Constraints.map}) {@link CUser}
     * constraint whose name equals {@code rule.name()} is replaced by the instantiated body of
     * {@code rule} (see {@code applyToConstraint}). All other constraints are left unchanged.
     *
     * @param rule the rule whose body is inlined
     * @param ith zero-based index of the matching call to replace
     * @param into the rule whose body is rewritten
     * @return {@code into} with its label cleared and the new body, or {@code Optional.empty()} if
     *         the body contains fewer than {@code ith + 1} matching calls
     */
    public static Optional<Rule> inline(Rule rule, int ith, Rule into) {
        FreshVars fresh = new FreshVars(into.varSet());
        // Counts matching calls seen so far; mutated from inside the lambda.
        AtomicInteger i = new AtomicInteger(0);
        IConstraint newBody = Constraints.map(c -> {
            if (!(c instanceof CUser)) {
                return c;
            }
            CUser constraint = (CUser)c;
            if (!constraint.name().equals(rule.name())) {
                return c;
            }
            if (i.getAndIncrement() != ith) {
                return c;
            }
            return RuleUtil.applyToConstraint(fresh, rule, constraint.args());
        }, false).apply(into.body());
        if (i.get() <= ith) {
            return Optional.empty();
        }
        return Optional.of(into.withLabel("").withBody(newBody));
    }

    /**
     * Builds the constraint for applying {@code rule} symbolically to {@code args}:
     * {@code exists vs. (args) == (params) /\ body}. The rule's parameter variables are renamed to
     * fresh variables and wildcards in the parameters become fresh variables. {@code vs} is the set
     * returned by {@code fresh.reset()}, presumably the variables created since the last reset.
     *
     * <p>NOTE(review): the equalities returned by {@code asTerm} ({@code p_eqs._2()}) are discarded
     * here. If parameter patterns can produce such equalities (e.g. as-patterns, if supported), they
     * are lost in the inlined constraint. Confirm whether that is intended.
     */
    private static IConstraint applyToConstraint(FreshVars fresh, Rule rule, List<? extends ITerm> args) {
        IRenaming swap = fresh.fresh(rule.paramVars());
        Pattern rulePatterns = TermPattern.P.newTuple((Iterable<? extends Pattern>)rule.params()).eliminateWld(() -> fresh.fresh("_"));
        Tuple2<ITerm, List<Tuple2<ITermVar, ITerm>>> p_eqs = rulePatterns.asTerm(v -> (ITermVar)v.get());
        ITerm p = swap.apply(p_eqs._1());
        ITerm t = TermBuild.B.newTuple(args);
        CEqual eq = new CEqual(t, p);
        Set.Immutable<ITermVar> newVars = fresh.reset();
        IConstraint newConstraint = Constraints.exists(newVars, new CConj(eq, rule.body().apply(swap)));
        return newConstraint;
    }

    /**
     * Simplifies a rule body by flattening it and solving its equalities and inequalities eagerly.
     *
     * <p>The body is processed with a worklist:
     * <ul>
     *   <li>conjunctions are decomposed (via {@code Constraints.disjoin}) and their parts re-queued;</li>
     *   <li>existentials get their variables renamed fresh and their body re-queued;</li>
     *   <li>equalities are unified, and inequalities disunified, into a local unifier;</li>
     *   <li>every other constraint is kept unchanged.</li>
     * </ul>
     * The new body is the unifier converted back to constraints ({@code StateUtil.asConstraint}),
     * conjoined with the kept constraints, under an existential over the fresh variables.
     *
     * @param rule the rule to simplify
     * @return the simplified rule, or {@code Optional.empty()} if an equality or inequality could
     *         not be added to the unifier (including an occurs-check failure)
     */
    public static Optional<Rule> simplify(Rule rule) {
        ArrayList constraints = Lists.newArrayList();
        IUniDisunifier.Transient unifier = PersistentUniDisunifier.Immutable.of().melt();
        FreshVars fresh = new FreshVars(rule.paramVars());
        // Seeded with push, then used via addLast/removeLast, i.e. as a LIFO stack.
        LinkedList worklist = Lists.newLinkedList();
        worklist.push(rule.body());
        while (!worklist.isEmpty()) {
            IConstraint constraint = (IConstraint)worklist.removeLast();
            // One handler per constraint kind. Only conj/equal/exists/inequal do work; the
            // remaining handlers keep the constraint as-is. A false result aborts simplification.
            boolean okay = constraint.match(Constraints.cases(c -> {
                constraints.add(c);
                return true;
            }, conj -> {
                Constraints.disjoin(conj).forEach(worklist::addLast);
                return true;
            }, equal -> {
                try {
                    return unifier.unify(equal.term1(), equal.term2()).isPresent();
                }
                catch (OccursException e) {
                    return false;
                }
            }, exists -> {
                IRenaming renaming = fresh.fresh(exists.vars());
                worklist.addLast(exists.constraint().apply(renaming));
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, inequal -> unifier.disunify(inequal.universals(), inequal.term1(), inequal.term2()).isPresent(), c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }, c -> {
                constraints.add(c);
                return true;
            }));
            if (okay) continue;
            return Optional.empty();
        }
        Set.Immutable<ITermVar> newVars = fresh.reset();
        IConstraint newBody = Constraints.exists(newVars, Constraints.conjoin(Iterables.concat(StateUtil.asConstraint(unifier), (Iterable)constraints)));
        Rule newRule = Rule.builder().from(rule).body(newBody).build();
        return Optional.of(newRule);
    }

    /**
     * Generates rule "fragments" by repeatedly expanding user-constraint calls in rule bodies.
     *
     * <p>The algorithm:
     * <ol>
     *   <li>The starting rules are the order-independent rules of every predicate accepted by
     *       {@code includePredicate}, filtered by {@code includeRule(name, label)}.</li>
     *   <li>Starting rules whose bodies contain no call to one of those predicates become the
     *       initial fragments and are removed from the starting rules.</li>
     *   <li>Each of the {@code generations} rounds takes every remaining starting rule and
     *       replaces its calls to included predicates with instantiated bodies of the current
     *       fragments for those predicates (see {@code lambda$39}).</li>
     *   <li>Each candidate is simplified, and the ones that simplify successfully are added as
     *       new fragments after the round.</li>
     * </ol>
     *
     * @param rules the rule set to draw rules from
     * @param includePredicate selects the predicates (rule names) that participate
     * @param includeRule selects individual rules by (name, label)
     * @param generations number of expansion rounds; a value of 0 or less yields only the initial
     *        fragments
     * @return an immutable multimap from predicate name to its fragments
     */
    public static SetMultimap<String, Rule> makeFragments(RuleSet rules, Predicate1<String> includePredicate, Predicate2<String, String> includeRule, int generations) {
        HashMultimap fragments = HashMultimap.create();
        HashMultimap newRules = HashMultimap.create();
        for (String ruleName : rules.getRuleNames()) {
            if (!includePredicate.test(ruleName)) continue;
            for (Object r : rules.getOrderIndependentRules(ruleName)) {
                if (!includeRule.test(((Rule)r).name(), ((Rule)r).label())) continue;
                newRules.put((Object)ruleName, r);
            }
        }
        // Accepts CUser constraints that call one of the included predicates (body in lambda$37).
        PartialFunction1 expandable = arg_0 -> RuleUtil.lambda$37((SetMultimap)newRules, arg_0);
        for (Map.Entry e : newRules.entries()) {
            // Only rules with no expandable calls in their body seed the fragments.
            if (!Constraints.collectBase(expandable, false).apply(((Rule)e.getValue()).body()).isEmpty()) continue;
            fragments.put((Object)((String)e.getKey()), (Object)((Rule)e.getValue()));
        }
        fragments.forEach((arg_0, arg_1) -> ((SetMultimap)newRules).remove(arg_0, arg_1));
        int g = 0;
        while (g < generations) {
            // Collected separately so this round only expands with fragments from earlier rounds.
            HashMultimap generation = HashMultimap.create();
            for (Map.Entry e : newRules.entries()) {
                String name = (String)e.getKey();
                Rule r = (Rule)e.getValue();
                FreshVars fresh = new FreshVars(r.varSet());
                // Candidate bodies from expanding calls with current fragments (body in lambda$39).
                List cs = Constraints.flatMap(arg_0 -> RuleUtil.lambda$39(fresh, expandable, (SetMultimap)fragments, arg_0), false).apply(r.body()).collect(Collectors.toList());
                for (IConstraint c : cs) {
                    Rule f = r.withLabel("").withBody(c);
                    Optional<Rule> sf = RuleUtil.simplify(f);
                    if (!sf.isPresent()) continue;
                    generation.put((Object)name, (Object)sf.get());
                }
            }
            fragments.putAll((Multimap)generation);
            ++g;
        }
        return ImmutableSetMultimap.copyOf((Multimap)fragments);
    }

    /**
     * Reports to {@code onVar} every variable that {@code Constraints.freeVars} finds free in the
     * rule body, except the rule's parameter variables.
     *
     * @param rule the rule to inspect
     * @param onVar callback invoked for each reported variable
     */
    public static void freeVars(Rule rule, Action1<ITermVar> onVar) {
        Set<ITermVar> paramVars = rule.paramVars();
        Constraints.freeVars(rule.body(), (ITermVar v) -> {
            if (!paramVars.contains(v)) {
                onVar.apply((ITermVar)v);
            }
        });
    }

    /**
     * Reports the variables of the rule body to {@code onVar} by delegating to
     * {@code Constraints.vars}.
     *
     * @param rule the rule to inspect
     * @param onVar callback invoked for each reported variable
     */
    public static void vars(Rule rule, Action1<ITermVar> onVar) {
        Constraints.vars(rule.body(), onVar);
    }

    /**
     * Decompiler-synthesized body of the {@code expandable} partial function in
     * {@link #makeFragments}.
     *
     * @return {@code c} as a {@link CUser} if it is a user constraint whose name is a key of
     *         {@code setMultimap}; otherwise {@code Optional.empty()}
     */
    private static /* synthetic */ Optional lambda$37(SetMultimap setMultimap, IConstraint c) {
        return c instanceof CUser && setMultimap.containsKey((Object)((CUser)c).name()) ? Optional.of((CUser)c) : Optional.empty();
    }

    /**
     * Decompiler-synthesized body of the expansion function in {@link #makeFragments}.
     *
     * <p>If {@code partialFunction1} accepts {@code c}, returns one instantiation (via
     * {@code applyToConstraint}) of each fragment registered in {@code setMultimap} for the called
     * predicate, using the call's arguments. This stream is empty when that predicate has no
     * fragments yet. Otherwise returns {@code c} unchanged as a single-element stream.
     */
    private static /* synthetic */ Stream lambda$39(FreshVars freshVars, PartialFunction1 partialFunction1, SetMultimap setMultimap, IConstraint c) {
        Optional u = (Optional)partialFunction1.apply(c);
        if (u.isPresent()) {
            return setMultimap.get((Object)((CUser)u.get()).name()).stream().map(f -> RuleUtil.applyToConstraint(freshVars, f, ((CUser)u.get()).args()));
        }
        return Stream.of(c);
    }
}

