/-
# Problem 27

(Intermediate 🌟🌟) Group the elements of a set into disjoint subsets.

Write a function that generates all the possibilities and returns them in a list.

Note: Originally there were two questions, (a) and (b), but I omitted (a) because it is a special case of (b).
-/

-- Universe level used by the `Monad List` instance below.
universe u
-- Element type shared by the list functions in this file.
variable {α : Type}

/-- `List` is a Monad.
Since Lean does not use lazy evaluation, no Monad instance of List is defined in the Lean standard library for performance reasons.

The fields are written out directly instead of referring to `List.pure` / `List.bind`:
those names are deprecated in recent Lean releases (in favour of `List.singleton` /
`List.flatMap`), whereas list literals, `foldr` and `map` are available in every version.

* `pure a` is the singleton list `[a]`.
* `bind xs f` applies `f` to every element of `xs` and concatenates the results in order.
* `map f xs` applies `f` to every element of `xs`. -/
instance : Monad List.{u} where
  pure := fun a => [a]
  -- Right fold keeps the results in the same order as the elements of `xs`.
  bind := fun xs f => xs.foldr (fun x acc => f x ++ acc) []
  map := fun f xs => xs.map f

/-- `xs.split n` lists every way of picking `n` elements out of `xs`.

Each result is a pair `(chosen, rest)`: `chosen` holds the `n` picked elements and
`rest` the remaining ones, both in their original relative order.
Results are ordered so that selections containing earlier elements come first.

* `xs.split 0 = [([], xs)]`
* If `n` exceeds `xs.length` there is no valid selection and the result is `[]`. -/
def List.split {α : Type} (n : Nat) (xs : List α) : List (List α × List α) :=
  match n, xs with
  -- Nothing left to pick: everything that remains is the "rest".
  | 0, rest => [([], rest)]
  -- Still need elements but the list is exhausted: no selection possible.
  | _ + 1, [] => []
  | k + 1, x :: tail =>
    -- Either `x` is picked (then pick `k` more from the tail) ...
    let taking := (List.split k tail).map fun p => (x :: p.1, p.2)
    -- ... or `x` is left over (then all `k + 1` come from the tail).
    let skipping := (List.split (k + 1) tail).map fun p => (p.1, x :: p.2)
    taking ++ skipping

#guard [1, 2].split 1 = [([1], [2]), ([2], [1])]

#guard [1, 2, 3].split 2 = [([1, 2], [3]), ([1, 3], [2]), ([2, 3], [1])]

#guard [1, 2, 3].split 3 = [([1, 2, 3], [])]

#guard [1, 2, 3].split 0 = [([], [1, 2, 3])]

/-- `group pattern xs` lists every way of dividing `xs` into consecutive groups whose
sizes are given by `pattern`, in that order.

The first group is chosen with `List.split`, and the remaining groups are formed
recursively from the elements that were not chosen.

* An empty `pattern` yields the single grouping `[]`; any elements still left in `xs`
  at that point are ignored.
* If the sizes in `pattern` add up to more than `xs.length`, the result is `[]`. -/
def group {α : Type} (pattern : List Nat) (xs : List α) : List <| List <| List α :=
  match pattern with
  | [] => [[]]
  | size :: sizes =>
    -- For every choice of the first group, prepend it to each grouping of the leftovers.
    (xs.split size).foldr
      (fun p acc => (group sizes p.2).map (fun gs => p.1 :: gs) ++ acc)
      []

#guard group [1, 2] [1, 2, 3] = [[[1], [2, 3]], [[2], [1, 3]], [[3], [1, 2]]]

#guard group [2, 1] [2, 4, 6] = [[[2, 4], [6]], [[2, 6], [4]], [[4, 6], [2]]]

#guard group [1, 1] [1, 2] = [[[1], [2]], [[2], [1]]]

-- The following code is for testing; you should not edit it.

/-- Pattern of a 2D `List`: the length of each inner list, in order.

For example `[[1], [2, 3]].pattern = [1, 2]`. -/
def List.pattern {α : Type} (xs : List (List α)) : List Nat :=
  xs.map fun inner => inner.length

/-- Factorial of a natural number: `0! = 1` and `(k + 1)! = (k + 1) * k!`. -/
def Nat.factorial (n : Nat) : Nat :=
  match n with
  | .zero => 1
  | .succ k => k.succ * Nat.factorial k

/-- Sanity-checks `group pattern xs` and prints `OK!` when every check passes.

Throws an `IO.userError` when:
* the sizes in `pattern` do not add up to `xs.length` (invalid test case);
* some grouping in the result does not have group sizes equal to `pattern`;
* the result contains duplicate groupings;
* the number of groupings differs from `xs.length! / (p₁! * p₂! * ...)`, where the
  `pᵢ` are the entries of `pattern`. -/
def runTest {α : Type} [ToString α] [BEq α] (pattern : List Nat) (xs : List α) : IO Unit := do
  -- The test case itself must be consistent before `group` is exercised.
  if pattern.foldl (· + ·) 0 != xs.length then
    throw <| IO.userError s!"invalid test case: the sum of pattern should be equal to the length of the input list."

  let result := group pattern xs
  -- Every grouping must have exactly the requested group sizes.
  let pattern_check := result.map List.pattern
    |>.all (· = pattern)
  if not pattern_check then
    throw <| .userError s!"pattern check failed: some elements of result has invalid pattern."

  -- Removing duplicates must not shrink the list.
  let distinct_check := result.eraseDups.length = result.length
  if not distinct_check then
    throw <| .userError s!"distinct check failed: the elements of result should be distinct."

  -- Expected count: factorial of the input length divided by the product of the
  -- factorials of the group sizes.
  let expected_length := pattern.map Nat.factorial
    |>.foldl (· * ·) 1
    |> (fun x => xs.length.factorial / x)
  let length_check := result.length = expected_length
  if not length_check then
    throw <| .userError s!"length check failed: the length of result should be equal to {expected_length}."

  IO.println "OK!"

-- Test runs. By the formula in `runTest`'s length check, the expected numbers of
-- groupings are 6! / (1! * 2! * 3!) = 60, 8! / (2! * 2! * 4!) = 420 and
-- 9! / (2! * 3! * 4!) = 1260 respectively.
#eval runTest [1, 2, 3] <| List.range 6

#eval runTest [2, 2, 4] <| List.range 8

#eval runTest [2, 3, 4] <| List.range 9