[pytorch] 3. 모델 학습 (파이토치 학습 절차)

2023. 4. 5. 23:48·pytorch
728x90
반응형

모델을 학습을 시킨다는 것은 y = wx + b 라는 함수에서 w와 b의 적절한 값을 찾는다는 것을 의미합니다. w와 b에 임의의 값을 적용하여 시작하여 오차가 줄어들어 전역 최소점에 이를 때까지 파라미터 (w,b)를 계속 수정합니다.

 

가장 먼저 필요한 절차가 optimizer.zero_grad() 함수를 이용하여 기울기를 초기화하는 것입니다. 파이토치는 기울기 값을 계산하기 위해 loss.backward() 함수를 이용하는데, 이것을 사용하면 새로운 기울기 값이 이전 기울기 값에 누적하여 계산됩니다. 이 방법은 순환신경망(RNN) 모델을 구현할 때 효과적이지만 누적 계산이 필요하지 않는 모델에 대해서는 불필요합니다. 따라서 기울기 값에 대해 누적 계산이 필요하지 않을 때는 입력 값을 모델에 적용하기 전 optimizer.zerograd() 함수를 호출하여 미분 값이 누적되지 않게 초기화해 주어야 합니다.

 

딥러닝 학습 절차 파이토치 학습 절차 ( 모델 학습 과정)
모델, 손실 함수, 옵티마이저 정의 모델, 손실함수, 옵티마이저 정의
optimizer.zero_grad() : 전방향 학습, 기울기 초기화
전방향 학습(입력 -> 출력 계산) output = model(input) : 출력 계산
손실 함수로 출력과 정답의 차이(오차) 계산 loss = loss_fn(output, target) : 오차 계산
역전파 학습(기울기 계산) loss.backward() : 역전파 학습 (기울기 계산)
기울기 업데이트 optimizer.step() : 기울기 업데이트

다음은 loss.backward() 함수를 이용하여 기울기를 자동으로 계산합니다. loss.backward()는 배치가 반복될 때마다 오차가 중첩적으로 쌓이게 되므로 zero_grad()를 사용하여 미분 값을 0으로 초기화 합니다.

 

<모델 훈련 예시 코드>

for epoch in range(100):
	yhat = model(x_train) # 입력 -> 출력
    loss = criterion(yhat, y_train) # 오차 계산
    optimizer.zero_grad() # 기울기 초기화
    loss.backward()       # 기울기 계산
    optimizer.step()      # 기울기 업데이트

 

728x90
반응형
저작자표시 (새창열림)

'pytorch' 카테고리의 다른 글

[pytorch] CNN (합성곱 신경망)의 구조  (0) 2023.04.09
[pytorch] Dropout  (0) 2023.04.06
[pytorch] 2. 모델 파라미터(손실 함수/ 옵티마이저 / 학습률 스케줄러)  (0) 2023.04.04
[pytorch] 1. 모델 정의 (nn.Module / nn.Sequential)  (0) 2023.03.22
[pytorch] tensor(텐서) 생성/이해/조작  (0) 2023.03.19
'pytorch' 카테고리의 다른 글
  • [pytorch] CNN (합성곱 신경망)의 구조
  • [pytorch] Dropout
  • [pytorch] 2. 모델 파라미터(손실 함수/ 옵티마이저 / 학습률 스케줄러)
  • [pytorch] 1. 모델 정의 (nn.Module / nn.Sequential)
ISFP의 블로그
ISFP의 블로그
이건 첫 번째 레슨, 업무에서 마주친 문제 해결 경험 공유하기 이건 두 번째 레슨, 개인적으로 공부한 데이터/AI 지식을 기록하기 이건 세 번째 레슨, 다른 사람과 비교하지 말고 오직 어제의 나와 비교하기
  • ISFP의 블로그
    resultofeffort
    ISFP의 블로그
  • 전체
    오늘
    어제
    • 분류 전체보기 (117)
      • python (25)
      • pythonML (27)
      • Linux (0)
      • 오류Error (8)
      • information (7)
      • Deep learning (5)
      • pytorch (29)
      • 코딩테스트 (4)
      • 밑바닥 DL (4)
      • 논문 리뷰 (3)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

    텍스트전처리
    cnn
    토큰화
    deeplearning
    Ai
    Deep Learning
    티스토리챌린지
    오블완
    데이터분석
    자연어처리
    머신러닝
    분류
    Pandas
    Python
    인공지능
    konlpy
    pytorch
    nlp
    딥러닝
    machinelearning
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.5
ISFP의 블로그
[pytorch] 3. 모델 학습 (파이토치 학습 절차)
상단으로

티스토리툴바