How can I stop getting false solutions to a CSP problem?

Question

I am given test cases, but my code comes up with solutions that aren't among the solutions given. Also, every time I run the code I get a different solution. Any help debugging this will be greatly appreciated.

Here are the constraints:

  • Red - No constraints
  • Yellow - equals the rightmost digit of the product of all its neighbors
  • Green - equals the rightmost digit of the sum of all its neighbors
  • Blue - equals the leftmost digit of the sum of all its neighbors
  • Violet - equals the leftmost digit of the product of all of its neighbors

Each node has a domain of {1,2,...,9}
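Written as code, the rules become a checker that can be run on any candidate answer, independent of the search. This is a minimal sketch, assuming node i's color is the i-th letter of a string such as the test case's 'YGVRB' and that arcs are undirected pairs. The helper names (leftmost_digit, expected_value, violations, brute_force) are hypothetical and not part of the original code. A five-node instance has only 9^5 = 59,049 assignments, so exhaustive enumeration is cheap and produces a ground-truth list to compare a solver's output against:

```python
import itertools
import math

def leftmost_digit(n):
    # Most significant decimal digit of a positive integer
    return int(str(n)[0])

def neighbors(node, arcs):
    # Arcs are undirected, so a node can appear at either end
    return [b if a == node else a for a, b in arcs if node in (a, b)]

def expected_value(node, colors, values, arcs):
    """Value required by the node's color, or None for an unconstrained red node."""
    vals = [values[n] for n in neighbors(node, arcs)]
    color = colors[node]
    if color == 'Y':
        return math.prod(vals) % 10              # rightmost digit of the product
    if color == 'G':
        return sum(vals) % 10                    # rightmost digit of the sum
    if color == 'B':
        return leftmost_digit(sum(vals))         # leftmost digit of the sum
    if color == 'V':
        return leftmost_digit(math.prod(vals))   # leftmost digit of the product
    return None                                  # 'R': no constraint

def violations(colors, values, arcs):
    # Indices of nodes whose value differs from what their color requires
    return [i for i in range(len(colors))
            if expected_value(i, colors, values, arcs) not in (None, values[i])]

def brute_force(colors, arcs, domain=range(1, 10)):
    # Every complete assignment with no violations (only feasible for tiny instances)
    return [list(vals) for vals in itertools.product(domain, repeat=len(colors))
            if not violations(colors, vals, arcs)]

colors = 'YGVRB'
arcs = [(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4)]
print(violations(colors, [7, 4, 2, 1, 8], arcs))  # [0, 1, 4]
print(violations(colors, [2, 6, 7, 6, 1], arcs))  # []
solutions = brute_force(colors, arcs)
print(len(solutions), solutions)
```

For the output reported in the question, violations flags nodes 0, 1 and 4. Yellow node 0 needs the last digit of 4 × 2 = 8, green node 1 needs the last digit of 7 + 2 + 1 + 8 = 18, and blue node 4 needs the first digit of 4 + 2 = 6, and none of these match. Any entry of the test-case list can be checked the same way.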

My Code:

import random

def get_neighbors(node, arcs):
    # Returns the neighbors of the given node
    neighbors = []
    for arc in arcs:
        if arc[0] == node:
            neighbors.append(arc[1])
        elif arc[1] == node:
            neighbors.append(arc[0])
    return neighbors

def is_valid_coloring(node, value, node_values, arcs):
    # Checks if the current node coloring satisfies the constraints
    neighbors = get_neighbors(node, arcs)
    color = node_values[node]
    
    if color == 'Y':
        product = 1
        for neighbor in neighbors:
            product *= node_values[neighbor]
        return value == product % 10
        
    elif color == 'G':
        s = sum(node_values[neighbor] for neighbor in neighbors)
        return value == s % 10

    elif color == 'B':
        sum = 0
        for neighbor in neighbors:
            sum += node_values[neighbor]
        return value == sum % 10
        
    elif color == 'V':
        product = 1
        for neighbor in neighbors:
            product *= node_values[neighbor]
        return value == product % 10
    else:
        return True

def select_unassigned_variable(node_values, nodes, arcs):
    """
    Returns an unassigned node that has the most conflicts with its neighbors.
    """
    unassigned_nodes = [i for i, val in enumerate(node_values) if val == 0]
    max_conflicts = -1
    max_conflict_nodes = []
    for node in unassigned_nodes:
        neighbors = get_neighbors(node, arcs)
        node_conflicts = 0
        for neighbor in neighbors:
            if node_values[neighbor] != 0 and not is_valid_coloring(neighbor, node_values[neighbor], node_values, arcs):
                node_conflicts += 1
        if node_conflicts > max_conflicts:
            max_conflicts = node_conflicts
            max_conflict_nodes = [node]
        elif node_conflicts == max_conflicts:
            max_conflict_nodes.append(node)
    if len(max_conflict_nodes) == 0:
        return None
    return random.choice(max_conflict_nodes)


def get_conflicts(node_values, node, arcs, nodes):
    conflicts = 0
    node_idx = node
    for arc in arcs:
        if node_idx == arc[1]:
            if node_values[node_idx] == node_values[arc[0]]:
                conflicts += 1
        if node_idx == arc[0]:
            if node_values[node_idx] == node_values[arc[1]]:
                conflicts += 1
    return conflicts

def min_conflicts(node_values, nodes, arcs, max_steps):
    # Solves the CSP using the min-conflicts algorithm
    for step in range(max_steps):
        unassigned_node = select_unassigned_variable(node_values, nodes, arcs)
        if unassigned_node is None:
            return node_values
        domain = [i for i in range(1, 10)]
        conflicts = [get_conflicts(node_values, unassigned_node, arcs, nodes)]
        min_conflicts = float('inf')
        min_conflict_values = []
        for value in domain:
            new_node_values = node_values.copy()
            new_node_values[unassigned_node] = value
            if is_valid_coloring(unassigned_node, value, new_node_values, arcs):
                num_conflicts = get_conflicts(new_node_values, unassigned_node, arcs, nodes)
                if num_conflicts < min_conflicts:
                    min_conflicts = num_conflicts
                    min_conflict_values = [value]
                elif num_conflicts == min_conflicts:
                    min_conflict_values.append(value)
        if min_conflict_values:
            new_value = random.choice(min_conflict_values)
            node_values[unassigned_node] = new_value
        else:
            # If there are no values that result in a minimum number of conflicts,
            # choose a random value from the domain
            new_value = random.choice(domain)
            node_values[unassigned_node] = new_value
        # If the new node values lead to an invalid coloring, try again with a different value
        if not is_valid_coloring(unassigned_node, new_value, node_values, arcs):
            node_values[unassigned_node] = random.choice([x for x in domain if x != new_value])
    return None


def solve_csp(nodes, arcs, max_steps):
    # Convert nodes to strings
    nodes = [str(node) for node in nodes]
    node_values = [0] * len(nodes)
    return min_conflicts(node_values, nodes, arcs, max_steps)



def main():
    # test Case 1

    nodes = 'YGVRB'
    arcs = [(0,1), (0,2), (1,2), (1,3), (1,4), (2,3), (2,4)]
    max_steps = 1000

    for _ in range(max_steps):
        sol = solve_csp(nodes, arcs, max_steps)
        if sol != []:
            break
            
    all_solutions = [[1, 1, 1, 7, 2],[2, 1, 2, 4, 3],[2, 6, 7, 6, 1],[2, 8, 9, 6, 1],
                    [3, 3, 1, 5, 4],[6, 2, 8, 7, 1],[6, 7, 8, 2, 1],[6, 9, 4, 8, 1]]

    if sol == []:
        print('No solution')
    else:
        if sol in all_solutions:
            print('Solution found:', sol)
        else:
            print('ERROR: False solution found:', sol)


if __name__ == '__main__':
    main()

Output:

ERROR: False solution found: [7, 4, 2, 1, 8]

I have also tried increasing the number of steps (max_steps), without any luck. I have double-checked that my constraints are accurate. Please let me know if there is something I have missed.
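For comparison, min-conflicts is normally described as a local search over complete assignments. Every node starts with a value, and the search repeatedly picks a variable involved in a violated constraint and gives it the value that leaves the fewest violations. The question's select_unassigned_variable only ever returns nodes whose value is still 0, so each node is assigned once and never revisited, and the run ends after a single greedy pass. Its get_conflicts also scores values by how many neighbors share the same value, which is a graph-coloring rule rather than one of the digit constraints listed. Below is a minimal sketch of the complete-assignment version for this problem, assuming the colors come from a string such as 'YGVRB'. All names are hypothetical, and the noise parameter is this sketch's own addition:

```python
import math
import random

def leftmost_digit(n):
    return int(str(n)[0])

def neighbors(node, arcs):
    return [b if a == node else a for a, b in arcs if node in (a, b)]

def satisfied(node, colors, values, arcs):
    # True if the node's value matches the rule for its color ('R' is always free)
    vals = [values[n] for n in neighbors(node, arcs)]
    color = colors[node]
    if color == 'Y':
        return values[node] == math.prod(vals) % 10
    if color == 'G':
        return values[node] == sum(vals) % 10
    if color == 'B':
        return values[node] == leftmost_digit(sum(vals))
    if color == 'V':
        return values[node] == leftmost_digit(math.prod(vals))
    return True

def violated(colors, values, arcs):
    return [i for i in range(len(colors)) if not satisfied(i, colors, values, arcs)]

def min_conflicts(colors, arcs, max_steps=100_000, noise=0.2, seed=None):
    rng = random.Random(seed)                      # seeded => reproducible runs
    domain = range(1, 10)
    values = [rng.choice(domain) for _ in colors]  # complete random start
    for _ in range(max_steps):
        bad = violated(colors, values, arcs)
        if not bad:
            return values
        # Node i's constraint involves i and its neighbors; repair one of them
        node = rng.choice(bad)
        var = rng.choice([node] + neighbors(node, arcs))
        if rng.random() < noise:
            values[var] = rng.choice(domain)       # random-walk step (sketch's addition)
            continue
        # Greedy step: pick the value leaving the fewest violated constraints
        scores = {}
        for v in domain:
            values[var] = v
            scores[v] = len(violated(colors, values, arcs))
        best = min(scores.values())
        values[var] = rng.choice([v for v, s in scores.items() if s == best])
    return None                                    # no solution within max_steps

colors = 'YGVRB'
arcs = [(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4)]
sol = min_conflicts(colors, arcs, seed=1)
if sol is None:
    print('No solution found within max_steps')
else:
    print('Solution found:', sol)
```

The random-walk step is there because purely greedy repair can stall in a local minimum. Plain min-conflicts has no such step. Because the search is randomized, different runs can legitimately return different valid solutions, and passing a seed reproduces one run while debugging. The sketch signals failure with None, so callers should test sol is None. The question's main compares sol with [], which None never equals, so a failed run there would be reported as "ERROR: False solution found: None".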

Answer 1

Score: 2

Here is something to get you started on your bug-hunt.

The code below uses the doctest library to add tests to your get_neighbors function. That function appears to work correctly, so the bug is most likely elsewhere. If you keep adding simple tests to all of your functions this way, and split the larger functions into smaller pieces, you will eventually find the bug(s) in your code:

import random
import doctest

def get_neighbors(node, arcs):
    """
    Returns the neighbors of the given node
    
    >>> get_neighbors(0, [(0, 1), (0, 2), (1, 2), (2, 3), (4, 0)])
    [1, 2, 4]
    >>> get_neighbors(1, [(0, 1), (0, 2), (1, 2), (2, 3), (4, 0)])
    [0, 2]
    """
    neighbors = []
    for arc in arcs:
        if arc[0] == node:
            neighbors.append(arc[1])
        elif arc[1] == node:
            neighbors.append(arc[0])
    return neighbors

# ... is_valid_coloring, select_unassigned_variable, get_conflicts,
# min_conflicts, solve_csp and main stay exactly as in the question ...

if __name__ == '__main__':
    doctest.testmod(verbose=True)
    main()
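Applying the same approach to is_valid_coloring from the question exposes the main problem. The line color = node_values[node] reads a number, since that list only ever holds 0 to 9, so none of the 'Y', 'G', 'B', 'V' branches match and the function returns True for every call. The 'B' and 'V' branches also use % 10, which gives the rightmost digit, while the rules ask for the leftmost one. In addition, the sum = 0 assignment in the 'B' branch makes sum a local name for the whole function, so the 'G' branch would raise UnboundLocalError once it is reached. Below is a possible fix with doctests. It assumes the color string (nodes in solve_csp) is passed in as an extra colors parameter, which is a hypothetical signature change:

```python
import doctest
import math

def get_neighbors(node, arcs):
    # Same behaviour as the original helper
    neighbors = []
    for arc in arcs:
        if arc[0] == node:
            neighbors.append(arc[1])
        elif arc[1] == node:
            neighbors.append(arc[0])
    return neighbors

def leftmost_digit(n):
    """
    >>> leftmost_digit(224)
    2
    """
    return int(str(n)[0])

def is_valid_coloring(node, value, colors, node_values, arcs):
    """
    Check whether `value` satisfies the constraint of `node`; `colors` is e.g. 'YGVRB'.

    >>> arcs = [(0, 1), (0, 2), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4)]
    >>> is_valid_coloring(0, 2, 'YGVRB', [0, 6, 7, 6, 1], arcs)  # 6*7 = 42
    True
    >>> is_valid_coloring(0, 7, 'YGVRB', [0, 4, 2, 1, 8], arcs)  # 4*2 = 8
    False
    >>> is_valid_coloring(2, 2, 'YGVRB', [7, 4, 0, 1, 8], arcs)  # 7*4*1*8 = 224
    True
    >>> is_valid_coloring(4, 8, 'YGVRB', [7, 4, 2, 1, 0], arcs)  # 4+2 = 6
    False
    """
    vals = [node_values[n] for n in get_neighbors(node, arcs)]
    color = colors[node]  # the color comes from the node string, not the values
    if color == 'Y':
        return value == math.prod(vals) % 10
    if color == 'G':
        return value == sum(vals) % 10  # built-in sum is no longer shadowed
    if color == 'B':
        return value == leftmost_digit(sum(vals))
    if color == 'V':
        return value == leftmost_digit(math.prod(vals))
    return True  # 'R' has no constraint

if __name__ == '__main__':
    doctest.testmod(verbose=True)
```

Once this version passes its own tests, the same kind of test can be written for each step of the search.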

Answer 2

Score: 0

Another way to debug graph-based code is to visualize your graph with the networkx package, which makes inconsistencies in your solution easier to spot.

You can give different nodes different colors, which helps you tell apart the different kinds of nodes in your problem:

import networkx as nx

G = nx.Graph()
G.add_nodes_from([
    (4, {"color": "red"}),
    (5, {"color": "green"}),
])

huangapple
  • Posted on 18 February 2023 at 08:50:41
  • If you repost this, please keep the link to this article: https://go.coder-hub.com/75490468.html