[230216] Chat GPT로 코테 풀이 개선하기 (백준 14500 python)
과거 https://cjk09083.tistory.com/28 에서 14500 문제를 풀이하였었는데
Chat GPT를 사용해본 김에 이를 활용해서 풀이를 개선시켜 보았다.
우선 기준이될 하드코딩으로 작성하였던 해답)
import sys
input=sys.stdin.readline
def getBox():
board = []
for _ in range(3):
board.append([0]*(m+6))
for _ in range(n):
tmp = [0,0,0]
tmp += list(map(int, input().split()))
tmp += [0,0,0]
board.append(tmp)
for _ in range(3):
board.append([0]*(m+6))
return board
def sol(i,j):
cnt = 0
# 1 직선 (세로놓기)
cnt = max(cnt, b[i][j] + b[i+1][j] + b[i+2][j] + b[i+3][j])
# 2 직선 (가로놓기)
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i][j+2] + b[i][j+3])
# 3 네모
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i+1][j+1] + b[i+1][j])
# 4 ㄴ // 왼상단 시작 오른 하단 종료. (세로가 큰 ㄴ)
cnt = max(cnt, b[i][j] + b[i+1][j] + b[i+2][j] + b[i+2][j+1])
# 5 ㄴ // 오른 상단 시작 왼 하단 종료. (4 좌우 대칭)
cnt = max(cnt, b[i][j+1] + b[i+1][j+1] + b[i+2][j+1] + b[i+2][j])
# 6 ㄴ // 왼하단 시작 오른 상단 종료. (가로가 긴 ㄱ)
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i][j+2] + b[i+1][j+2])
# 7 ㄴ // 왼하단 시작 오른 상단 종료. (6 대칭)
cnt = max(cnt, b[i+1][j] + b[i][j] + b[i][j+1] + b[i][j+2])
# 8 ㄴ // 왼상단 시작 오른 하단 종료. (세로가 긴 ㄱ)
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i+1][j+1] + b[i+2][j+1])
# 9 ㄴ // 왼상단 시작 오른 하단 종료. (8 대칭)
cnt = max(cnt, b[i][j+1] + b[i][j] + b[i+1][j] + b[i+2][j])
# 10 ㄴ // 오른 상단 시작 왼 하단 종료 (가로가 긴 ㄴ)
cnt = max(cnt, b[i][j] + b[i+1][j] + b[i+1][j+1] + b[i+1][j+2])
# 11 ㄴ // 오른 상단 시작 왼 하단 종료 (10 대칭)
cnt = max(cnt, b[i+1][j] + b[i+1][j+1] + b[i+1][j+2] + b[i][j+2])
# 12 ㄴㄱ // 왼 상단 시작 오른 하단 종료 (ㄴ 밑에 ㄱ)
cnt = max(cnt, b[i][j] + b[i+1][j] + b[i+1][j+1] + b[i+2][j+1])
# 13 ㄴㄱ // 왼 하단 시작 오른 상단 종료 (ㄴ 밑에 ㄱ 좌우 반전)
cnt = max(cnt, b[i][j+1] + b[i+1][j+1] + b[i+1][j] + b[i+2][j])
# 14 ㄴㄱ // 왼 상단 시작 오른 하단 종료 (ㄱ 우측에 ㄴ)
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i+1][j+1] + b[i+1][j+2])
# 15 ㄴㄱ // 오른 상단 시작 왼 하단 종료 (ㄱ 우측에 ㄴ 상하 반전)
cnt = max(cnt, b[i+1][j] + b[i+1][j+1] + b[i][j+1] + b[i][j+2])
# 16 ㅗ // ㅜ
cnt = max(cnt, b[i][j] + b[i][j+1] + b[i+1][j+1] + b[i][j+2])
# 17 ㅗ // ㅓ
cnt = max(cnt, b[i][j+1] + b[i+1][j+1] + b[i+2][j+1] + b[i+1][j])
# 18 ㅗ // ㅗ
cnt = max(cnt, b[i+1][j] + b[i+1][j+1] + b[i][j+1] + b[i+1][j+2])
# 19 ㅗ // ㅏ
cnt = max(cnt, b[i][j] + b[i+1][j] + b[i+1][j+1] + b[i+2][j])
return cnt
if __name__ == "__main__":
n, m = map(int,(input().split()))
b = getBox()
# print(*b, sep = "\n")
result = 0
for i in range(n):
for j in range(m):
result = max(result,sol(i+3,j+3))
print(result)
이를 Chat GPT에게 분석해달라 하였더니 아래와 같이 대답하였다.
내가 N+6 사이즈의 리스트를 만든이유와 sol 함수에서 동작하는 원리를 잘 설명해준다.
이제 아래와 같이 풀이를 개선해 보았다.
Q1) 좀더 속도를 빠르게 개선할 수 있을까?
A1)
각 개선 방법을 적용한 코드를 물어보았다.
Q2) 중복 계산을 피하도록 수정해줘
A2)
def sol(i,j):
cnt = 0
patterns = [
[0, 0], # padding
[0, 1, 2, 3], # case 1
[0, m, m*2, m*3], # case 2
[0, m, m+1, m*2+1], # case 3
[0, m, m*2, m*2+1], # case 4
[1, m+1, m*2+1, m*2], # case 5
[0, 1, m+1, m*2+1], # case 6
[1, m+1, m, m*2], # case 7
[0, 1, m+1, m*2+1], # case 8
[1, m+1, m+2, m+3], # case 9
[0, 1, 2, m+2], # case 10
[0, m, m+1, m*2+1], # case 11
[0, m, m+1, m*2], # case 12
[1, m+1, m, m-1], # case 13
[0, 1, m, m+1], # case 14
[1, m+1, m+2, m*2+1], # case 15
[0, 1, 2, m+1], # case 16
[1, m, m+1, m*2+1], # case 17
[1, m, m+1, m+2], # case 18
[0, m, m+1, m*2], # case 19
]
for pattern in patterns:
psum = sum([b[i + (p // m)][j + (p % m)] for p in pattern])
cnt = max(cnt, psum)
return cnt
실행 결과 인덱싱 에러가 발생해 다시 질문했다.
해당 부분을 수정해도 여전히 오류가 발생해 계속 오류 수정을 부탁했지만 점점 정답과 멀어지게 수정되었다.
따라서 아예 개선 방법을 지정해주도록 했다.
Q3) dfs를 사용하도록 코드를 수정해줘 (처음 코드 입력)
A3)
import sys
input=sys.stdin.readline
def getBox():
board = []
for _ in range(3):
board.append([0] * (m+6))
for _ in range(n):
tmp = [0,0,0]
tmp += list(map(int, input().split()))
tmp += [0,0,0]
board.append(tmp)
for _ in range(3):
board.append([0] * (m+6))
return board
def dfs(i, j, cnt, depth):
if depth == 4:
global result
result = max(result, cnt)
return
for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
ni, nj = i + di, j + dj
if visited[ni][nj]:
continue
visited[ni][nj] = True
dfs(ni, nj, cnt + b[ni][nj], depth + 1)
visited[ni][nj] = False
def sol(i, j):
global result
visited[i][j] = True
dfs(i, j, b[i][j], 1)
visited[i][j] = False
if __name__ == "__main__":
n, m = map(int,(input().split()))
b = getBox()
result = 0
visited = [[False] * (m+6) for _ in range(n+6)]
for i in range(n):
for j in range(m):
sol(i+3, j+3)
print(result)
수정된 코드를 확인 결과 T자 모양에 대해서는 계산이 되지않아 추가 질문을 하였다.
Q3) T 모양에 대해서는 탐색이 안되는것 같아.
A3)
수정후에 입력
4 10
1 2 1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1 2 1
1 2 1 2 1 2 1 2 1 2
2 1 2 1 2 1 2 1 2 1
에 대하여 결과가 8이 출력돼 다시 질문했다
Q4) 입력 ~~~ 에 대하여 결과가 8이 출력돼. 정답은 7이야
A4)
def getBox():
board = []
for _ in range(3):
board.append([0]*(m+6))
for _ in range(n):
tmp = [0]*3 # 수정된 부분
tmp += list(map(int, input().split()))
tmp += [0]*3 # 수정된 부분
board.append(tmp)
for _ in range(3):
board.append([0]*(m+6))
return board
위부분을 수정해도 결과는 마찬가지라 다시 질문하였다.
Q5) 여전히 결과가 8이 출력돼
A5)
어..?
아예 코테 문제를 입력해주었다.
Q6) 폴리오미노란 크기가 1×1인 정사각형을 여러 개 이어서 붙인 도형이며, 다음과 같은 조건을 만족해야 한다. 정사각형은 서로 겹치면 안 된다. 도형은 모두 연결되어 있어야 한다. 정사각형의 변끼리 연결되어 있어야 한다. 즉, 꼭짓점과 꼭짓점만 맞닿아 있으면 안 된다. 정사각형 4개를 이어 붙인 폴리오미노는 테트로미노라고 하며, 5가지가 있다. 아름이는 크기가 N×M인 종이 위에 테트로미노 하나를 놓으려고 한다. 종이는 1×1 크기의 칸으로 나누어져 있으며, 각각의 칸에는 정수가 하나 쓰여 있다. 테트로미노 하나를 적절히 놓아서 테트로미노가 놓인 칸에 쓰여 있는 수들의 합을 최대로 하는 프로그램을 작성하시오. 테트로미노는 반드시 한 정사각형이 정확히 하나의 칸을 포함하도록 놓아야 하며, 회전이나 대칭을 시켜도 된다.
A6)
import sys
# 가능한 모든 테트로미노 모양을 미리 정의합니다.
tetrominos = [
[(0, 0)],
[(0, 0), (0, 1)],
[(0, 0), (1, 0)],
[(0, 0), (0, 1), (0, 2)],
[(0, 0), (1, 0), (2, 0)],
[(0, 0), (0, 1), (1, 1), (1, 0)],
[(0, 0), (1, 0), (1, 1), (2, 1)],
[(0, 0), (1, 0), (2, 0), (2, 1)],
[(0, 0), (0, 1), (0, 2), (1, 2)],
[(0, 0), (0, 1), (1, 1), (2, 1)],
[(0, 0), (1, 0), (1, 1), (2, 0)],
[(0, 0), (1, 0), (1, 1), (2, 1)],
[(0, 0), (0, 1), (1, 0), (1, 1)],
[(0, 0), (0, 1), (0, 2), (1, 1)],
[(0, 0), (1, 0), (2, 0), (1, 1)],
[(0, 0), (1, 0), (2, 0), (1, -1)]
]
def dfs(x, y, cnt):
global ans, total
if cnt == 4:
ans = max(ans, total)
return
for i in range(16):
nx, ny = x + tetrominos[i][0][0], y + tetrominos[i][0][1]
if nx < 0 or ny < 0 or nx >= n or ny >= m:
continue
flag = True
for j in range(1, len(tetrominos[i])):
nnx, nny = x + tetrominos[i][j][0], y + tetrominos[i][j][1]
if nnx < 0 or nny < 0 or nnx >= n or nny >= m:
flag = False
break
nx, ny = nnx, nny
if flag and not visited[nx][ny]:
visited[nx][ny] = True
total += a[nx][ny]
dfs(nx, ny, cnt + 1)
total -= a[nx][ny]
visited[nx][ny] = False
if __name__ == "__main__":
n, m = map(int, sys.stdin.readline().split())
a = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
visited = [[False] * m for _ in range(n)]
ans = 0
for i in range(n):
for j in range(m):
visited[i][j] = True
total = a[i][j]
dfs(i, j, 1)
visited[i][j] = False
print(ans)
dfs를 활용한 코드이나
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
[0, 0, 0, 0, 0]
[1, 0, 0, 0, 0]
[0, 0, 0, 0, 0]
[1, 1, 1, 0, 0]
[0, 0, 0, 0, 0]
등의 형태도 테트로미노로 인식해버려 정답이 틀리게 나왔다.
이번엔 반대로 완성된 코드에서 T 블록을 고려하는 부분을 분석하게 만들어 봤다.
Q7)
import sys
def dfs(x, y, cnt):
global ans, total, visited
if cnt == 4:
ans = max(ans, total)
return
for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
nx, ny = x + dx, y + dy
if nx < 0 or nx >= n or ny < 0 or ny >= m or visited[nx][ny]:
continue
if cnt == 2:
visited[nx][ny] = True
total += a[nx][ny]
dfs(x, y, cnt + 1)
total -= a[nx][ny]
visited[nx][ny] = False
visited[nx][ny] = True
total += a[nx][ny]
dfs(nx, ny, cnt + 1)
total -= a[nx][ny]
visited[nx][ny] = False
if __name__ == "__main__":
n, m = map(int, sys.stdin.readline().split())
a = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
visited = [[False] * m for _ in range(n)]
ans = 0
for i in range(n):
for j in range(m):
visited[i][j] = True
total = a[i][j]
dfs(i, j, 1)
visited[i][j] = False
print(ans)
위 코드에서 'if cnt == 2:' 조건문은 무슨 용도야?
A7)
알면서 왜 직접은 못만드냐고.. 내 질문이 잘못됐나