import sys
import csv
import re
# Raise the csv module's per-field size cap as high as this platform accepts
# (presumably because individual lyrics fields can exceed the default cap).
# Start from sys.maxsize and shrink tenfold until field_size_limit() stops
# raising OverflowError (its argument must fit in a C long).
maxInt = sys.maxsize
while True:
    try:
        csv.field_size_limit(maxInt)
    except OverflowError:
        # Too large for the platform's C long: shrink and retry.
        maxInt = int(maxInt / 10)
    else:
        break

         
# One dict per genre, mapping (artist, title) -> lyrics with bracketed
# annotations removed.
rap = {}
rb = {}
rock = {}
pop = {}
country = {}
misc = {}

# Maximum number of distinct songs kept per genre.
limit = 500

# Route each CSV row to its genre's dict by the value of its 'tag' column;
# rows whose tag is not listed here are ignored.
genres = {
    "pop": pop,
    "rb": rb,
    "rock": rock,
    "rap": rap,
    "misc": misc,
    "country": country,
}

# Matches one bracketed annotation (e.g. a "[Chorus]" header) within a single
# line. Non-greedy so that a line like "[Verse 1] some words [x2]" keeps
# "some words" instead of losing everything between the first "[" and the
# last "]" on that line.
section_tag = re.compile(r"\[.*?\]")

#reducing the dataset to `limit` distinct songs for each genre
#creating separate dictionaries for each genre
# newline='' is what the csv module requires so that newlines embedded in
# quoted lyrics fields are read back unchanged.
with open('song_lyrics.csv', encoding='cp850', newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        # Stop reading once every genre has reached its quota.
        if all(len(d) >= limit for d in genres.values()):
            break
        songs = genres.get(row['tag'])
        # The quota counts distinct (artist, title) keys, so a duplicate key
        # (which overwrites an existing entry) does not use up a slot.
        if songs is not None and len(songs) < limit:
            songs[(row['artist'], row['title'])] = section_tag.sub("", row['lyrics'])

# Number of songs collected per genre (kept under their original names).
pop_ = len(pop)
rb_ = len(rb)
rock_ = len(rock)
rap_ = len(rap)
misc_ = len(misc)
country_ = len(country)
          
          
#processing the lyrics and calculating the semantic similarity values
from sentence_transformers import SentenceTransformer, util

# Load the embedding model once and reuse it for every genre.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


def _pairwise_similarity(songs):
    """Return an n x n list of lists of cosine similarities for *songs*' lyrics.

    *songs* maps a key to a lyrics string; songs are indexed in the dict's
    iteration order. Cell [i][j] with i < j holds a 1x1 tensor containing the
    cosine similarity of songs i and j; the diagonal and the lower triangle
    hold 0. Every row is a separate list object (the previous
    ``[[0] * l] * l`` construction aliased one list as every row, so each
    assignment was written into all rows at once).
    """
    lyrics = list(songs.values())
    n = len(lyrics)
    if n == 0:
        return []
    # Encode all lyrics in one batched call; returns an (n, dim) tensor.
    embeddings = model.encode(lyrics, convert_to_tensor=True)
    # Full n x n cosine-similarity matrix in a single call.
    sim = util.pytorch_cos_sim(embeddings, embeddings)
    # Keep only the upper triangle, as 1x1 slices, matching the per-pair
    # pytorch_cos_sim results the rest of the script consumes.
    return [
        [sim[i:i + 1, j:j + 1] if i < j else 0 for j in range(n)]
        for i in range(n)
    ]


pop_res = _pairwise_similarity(pop)
rb_res = _pairwise_similarity(rb)
rock_res = _pairwise_similarity(rock)
country_res = _pairwise_similarity(country)
misc_res = _pairwise_similarity(misc)
rap_res = _pairwise_similarity(rap)
      
      

#collecting the pairwise similarity values of each genre as plain floats
#(the per-genre averages are derived from these lists)
from statistics import mean


def toFloat(a):
    """Convert a single-element similarity value to a Python float.

    Accepts a single-element tensor or a plain number. ``float()`` works for
    tensors on any device, whereas the previous ``a.numpy()`` raised for
    tensors living on a CUDA device.
    """
    return float(a)


def _upper_triangle(matrix):
    """Return the above-diagonal cells of a square list-of-lists as floats.

    Only cells [i][j] with i < j are read: those are the ones the similarity
    step fills, so the diagonal/lower-triangle placeholders are skipped
    structurally instead of by testing ``!= 0`` (which would also drop a
    genuine similarity of exactly 0.0).
    """
    return [
        toFloat(cell)
        for i, row in enumerate(matrix)
        for cell in row[i + 1:]
    ]


# NOTE(review): `mean` is imported here but not applied within this section.
pop_resu = _upper_triangle(pop_res)
rb_resu = _upper_triangle(rb_res)
rock_resu = _upper_triangle(rock_res)
country_resu = _upper_triangle(country_res)
rap_resu = _upper_triangle(rap_res)
misc_resu = _upper_triangle(misc_res)
