embedding.py
import math

import torch
import torch.nn as nn


class WordEmbedding(nn.Module):
    """Scaled token embedding plus fixed sinusoidal positional encoding."""

    def __init__(self, vocab_len: int, max_seq_len: int, d_model: int):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_len, d_model)
        # Precompute the sinusoidal positional-encoding table once.
        pos = torch.arange(max_seq_len).unsqueeze(1)  # shape: (max_seq_len, 1)
        # Frequencies 1 / 10000^(2i / d_model), one per sin/cos dimension pair.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions
        # Buffer: moves with the module across devices, but is not a learned parameter.
        self.register_buffer('pe', pe)

    def forward(self, tokens):
        tokens = tokens.long().to(self.embedding.weight.device)
        # Scale embeddings by sqrt(d_model), following "Attention Is All You Need".
        emb = self.embedding(tokens) * math.sqrt(self.d_model)
        # Slice the table to the input length so sequences shorter than
        # max_seq_len broadcast correctly over the batch dimension.
        return emb + self.pe[: tokens.size(-1)]
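

# A minimal usage sketch, not part of the original file: the vocabulary size,
# sequence length, and model width below are illustrative values, assuming
# d_model is even so the sin/cos columns fill the encoding table exactly.
if __name__ == "__main__":
    model = WordEmbedding(vocab_len=10_000, max_seq_len=128, d_model=512)
    tokens = torch.randint(0, 10_000, (2, 64))  # batch of 2, length 64 <= max_seq_len
    out = model(tokens)
    print(out.shape)  # torch.Size([2, 64, 512])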