# 2024-06-21 21:26:27 -07:00 | 384 lines | 14 KiB | Python


"""
Use Yahoo Finance data
"""
import warnings
# Suppress FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import datetime as dt
import os
import pandas as pd
import numpy as np
# import yfinance as yf
import yahoo_fin.stock_info as si
import requests
from lxml import html
from io import StringIO
from time import sleep
# Watchlist page fetched by get_watchlist() (also the default of its
# `website` parameter).
WEBSITE = 'https://www.isolo.org/dokuwiki/knowledge_base/investing/watchlist'
# Number of symbols per download batch; get_price_volume() passes it to
# yf_batch_download() as `batch_size`.
BATCHSIZE = 20
# Pause between download batches, in seconds; get_price_volume() passes it to
# yf_batch_download() as `time_gap`.
TIMEGAP = 0.2
def fill_missing_data(df):
    """
    Fill gaps in the data: forward fill first, then back fill whatever is
    still missing (values with nothing before them to carry forward).
    """
    return df.ffill().bfill()
def symbol_to_path(symbol, base_dir=None):
    """Build the CSV file path for a ticker symbol."""
    directory = base_dir
    if directory is None:
        # No explicit folder: use MARKET_DATA_DIR, else the relative default.
        directory = os.environ.get("MARKET_DATA_DIR", '../data/')
    csv_name = "{}.csv".format(str(symbol))
    return os.path.join(directory, csv_name)
# def get_data_full(symbols, start_date, addSPY=True, colname = 'Adj Close'):
# """
# Read stock data (adjusted close) for given symbols from Yahoo Finance
# from start_date to the latest date available (usually the current date).
# """
# if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
# symbols = ['SPY'] + symbols
# df = yf.download(symbols, start = start_date)[colname]
# if len(symbols) == 1:
# df.name = symbols[0]
# df = df.to_frame()
# return df
# def get_data_full(symbols, start_date, addSPY=True, colname = 'Adj Close'):
# """
# Read stock data (adjusted close) for given symbols from CSV files
# from start_date to the latest date available in the CSV files.
# """
# df_temp = pd.read_csv(symbol_to_path('SPY'), index_col='Date',
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
# df_temp = df_temp.rename(columns={colname: 'SPY'})
# end_date = df_temp.index.values[-1]
# dates = pd.date_range(start_date, end_date)
# df = pd.DataFrame(index=dates)
# df = df.join(df_temp)
# df = df.dropna()
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
# # symbols = ['SPY'] + symbols
# for symbol in symbols:
# df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date',
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
# df_temp = df_temp.rename(columns={colname: symbol})
# df = df.join(df_temp)
# # if symbol == 'SPY': # drop dates SPY did not trade
# # df = df.dropna(subset=["SPY"])
# if not addSPY:
# df = df[symbols]
# return df
def get_data_range(df, dates):
    """
    Return the rows of the full data set whose index appears in `dates`
    (inner join of `dates` with the index of `df`).
    """
    return pd.DataFrame(index=dates).join(df, how='inner')
def yf_download(symbols, start, end):
    """
    Download adjusted close and volume for each symbol via yahoo_fin.

    Parameters
    ----------
    symbols : list
        stock symbols.
    start : datetime
        start date.
    end : datetime
        stop date, forwarded to yahoo_fin as ``end_date``.
        NOTE(review): Yahoo presumably treats this as an upper bound on the
        bar timestamp, so a midnight ``end`` may exclude that day's bar --
        confirm against the yahoo_fin documentation.

    Returns
    -------
    df : dataframe
        outer join of all symbols on the index returned by yahoo_fin, with
        two-level columns named ("param", "tick"): param is "Adj Close" or
        "Volume", tick is the symbol. Empty (but with the same column
        levels) when ``symbols`` is empty.
    """
    empty_columns = pd.MultiIndex(levels=[["Adj Close", "Volume"], []],
                                  codes=[[], []], names=["param", "tick"])
    df = pd.DataFrame(columns=empty_columns)
    for sym in symbols:
        raw = si.get_data(sym, start_date=start, end_date=end)
        # Rename on a fresh frame rather than in place on a column slice.
        tmp = raw[["adjclose", "volume"]].rename(
            columns={"adjclose": "Adj Close", "volume": "Volume"})
        tmp.columns = pd.MultiIndex.from_product(
            [list(tmp.columns), [sym]], names=["param", "tick"])
        df = df.join(tmp, how='outer')
    return df
# def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'):
# """
# Read stock data (adjusted close) for given symbols from Yahoo Finance
# """
# org_sym = symbols
# sd = dates[0]
# ed = dates[-1]
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
# if 'SPY' not in symbols:
# symbols = ['SPY'] + symbols
# df = yf.download(symbols, start=sd, end = ed)[colname]
# if len(symbols) == 1:
# df.name = symbols[0]
# df = df.to_frame()
# df = df.dropna(subset=['SPY'])
# df = fill_missing_data(df)
# if addSPY==False:
# # df = df.drop(columns=['SPY'])
# df = df[org_sym]
# return df
def yf_batch_download(symbols, start, end, batch_size, time_gap,
                      downloader=None):
    """
    download in small batches to avoid connection closure by host

    Parameters
    ----------
    symbols : list
        stock symbols.
    start : datetime
        start date.
    end : datetime
        stop date.
    batch_size : integer
        maximum number of symbols per batch; must be >= 1.
    time_gap : float
        pause between two consecutive batches, in seconds or fraction of
        seconds.
    downloader : callable, optional
        function called as ``downloader(batch, start, end)`` that returns a
        dataframe for one batch. Defaults to ``yf_download``.

    Returns
    -------
    df : dataframe
        the per-batch dataframes concatenated column-wise; an empty
        dataframe when ``symbols`` is empty.

    Raises
    ------
    ValueError
        if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if downloader is None:
        downloader = yf_download
    symbols = list(symbols)
    frames = []
    # Fixed-stride chunks: no batch ever exceeds batch_size, and the
    # remainder (if any) simply forms a shorter final batch.
    for offset in range(0, len(symbols), batch_size):
        if frames:
            # Pause only between requests, not before the first one.
            sleep(time_gap)
        frames.append(downloader(symbols[offset:offset + batch_size],
                                 start, end))
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, axis=1)
def get_price_volume(symbols, dates, addSPY=False):
    """
    Download adjusted close and volume for the given symbols from Yahoo
    Finance.

    SPY is always downloaded together with the requested symbols and is used
    as the reference series: rows where SPY has no adjusted close are
    dropped. Remaining price gaps are forward filled and then back filled;
    missing volumes are set to 0.

    Parameters
    ----------
    symbols : list
        stock symbols (a single symbol string is also accepted).
    dates : sequence of datetime
        only the first and the last entries are used, as the start and the
        end of the download window.
    addSPY : bool
        if False (default) only the columns for ``symbols`` are returned, in
        the order given; otherwise every downloaded symbol, SPY included, is
        returned.

    Returns
    -------
    price : dataframe
        adjusted close, one column per symbol.
    volume : dataframe
        volume, one column per symbol.

    Notes
    -----
    Data is downloaded on every call. An earlier design kept a local pickle
    cache (plus a "_refresh.csv" list of symbols to re-download after splits
    or spin-offs); that cache is not active, so no gap-filling against cached
    data is attempted here.
    """
    if isinstance(symbols, str):
        symbols = [symbols]
    org_sym = list(symbols)
    sd = dates[0]
    ed = dates[-1]
    # SPY is needed as the reference series even when it was not requested.
    wanted = org_sym if 'SPY' in org_sym else ['SPY'] + org_sym
    # Download each ticker once: duplicate column labels would make the
    # join in the downloader fail. Sorting gives a stable column order.
    download_syms = sorted(set(wanted))
    df = yf_batch_download(download_syms, start=sd, end=ed,
                           batch_size=BATCHSIZE, time_gap=TIMEGAP)
    # Keep only the rows for which SPY has an adjusted close.
    df = df.dropna(subset=[('Adj Close', 'SPY')])
    price = fill_missing_data(df['Adj Close'])
    volume = df['Volume'].fillna(0)
    if not addSPY:
        price = price[org_sym]
        volume = volume[org_sym]
    return price, volume
# def get_price_volume_online(symbols, dates, addSPY=False):
# """
# Read stock data (adjusted close and volume) for given symbols from Yahoo
# Finance
# """
# org_sym = symbols
# sd = dates[0]
# ed = dates[-1]
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
# if 'SPY' not in symbols:
# symbols = ['SPY'] + symbols
# df = yf.download(symbols, start=sd, end = ed)
# if len(symbols) == 1:
# df = df.dropna(subset = ['Adj Close'])
# else:
# df = df.dropna(subset=[('Adj Close', 'SPY')])
# price = df['Adj Close']
# price = fill_missing_data(price)
# volume = df['Volume']
# volume = volume.fillna(0)
# if len(symbols) == 1:
# price.name = symbols[0]
# volume.name = symbols[0]
# price = price.to_frame()
# volume = volume.to_frame()
# if addSPY==False:
# price = price[org_sym]
# volume = volume[org_sym]
# return price, volume
def get_watchlist(website: str = WEBSITE):
    """
    Fetch the watchlist embedded in a DokuWiki page and parse it as CSV.

    Parameters
    ----------
    website : str
        URL of the page to fetch; defaults to WEBSITE.

    Returns
    -------
    df : dataframe
        the CSV text of the page's <pre> block, indexed by its 'Symbol'
        column; lines starting with '#' are treated as comments and empty
        fields are kept as empty strings (no NaN conversion).

    Raises
    ------
    requests.HTTPError
        if the server answers with an error status.
    IndexError
        if the page does not contain the expected <pre> block.
    """
    # A timeout keeps a stalled server from hanging the caller indefinitely.
    page = requests.get(website, timeout=30)
    # page = requests.get(website, verify=False) # skip certificate check for https
    page.raise_for_status()
    tree = html.fromstring(page.content)
    matches = tree.xpath('//*[@id="dokuwiki__content"]/div[1]/div/div[3]/div/pre/text()')
    if not matches:
        raise IndexError("watchlist <pre> block not found at {}".format(website))
    return pd.read_csv(StringIO(matches[0]), index_col='Symbol',
                       comment='#', na_filter=False)
# def get_watchlist(file_name: str = 'watchlist.csv'):
# df = pd.read_csv(file_name, index_col = 'Symbol',
# comment = '#', na_filter=False)
# return df
# def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'):
# """
# Read stock data (adjusted close) for given symbols from CSV files.
# (done) TODO: there are nan values in the data when addSPY=False is passed. The
# strategy should be using SPY to clean the data first including fill
# forward and fill backward, then to drop the SPY if addSPY=False
# """
# org_sym = symbols
# df = pd.DataFrame(index=dates)
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
# # symbols = ['SPY'] + symbols
# if 'SPY' not in symbols:
# symbols = ['SPY'] + symbols
# for symbol in symbols:
# df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date',
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
# df_temp = df_temp.rename(columns={colname: symbol})
# df = df.join(df_temp)
# if symbol == 'SPY': # drop dates SPY did not trade
# df = df.dropna(subset=["SPY"])
# # fill missing data
# df = fill_missing_data(df)
# if addSPY == False: # drop SPY
# # df = df.drop(columns=['SPY'])
# df = df[org_sym]
# return df
def plot_data(df, axs=None, title=None, xlabel='', ylabel=''):
    """
    Plot stock prices with a custom title and meaningful axis labels.

    Parameters
    ----------
    df : dataframe
        data to plot; its ``plot`` method is used.
    axs : axes, optional
        axes to draw on. When omitted (None, or an empty list for backward
        compatibility) ``df.plot`` is left to choose the axes itself.
    title : optional
        passed to ``df.plot`` as ``title``.
    xlabel : str
        x axis label.
    ylabel : str
        y axis label.
    """
    # Identity/type checks instead of `axs == []`, which compares
    # element-wise when axs is array-like.
    no_axes_given = axs is None or (isinstance(axs, (list, tuple)) and not axs)
    if no_axes_given:
        ax = df.plot(title=title)
    else:
        ax = df.plot(ax=axs, title=title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid()
# def plot_data(df, title=[], xlabel='', ylabel=''):
# import matplotlib.pyplot as plt
# """Plot stock prices with a custom title and meaningful axis labels."""
# ax = df.plot(title=title, fontsize=12, figsize=(10, 7))
# ax.set_xlabel(xlabel)
# ax.set_ylabel(ylabel)
# plt.grid()
# plt.show()
def get_orders_data_file(basefilename):
    """
    Open `basefilename` from the orders folder (ORDERS_DATA_DIR, default
    'orders/') and return the open file object; the caller closes it.
    """
    orders_dir = os.environ.get("ORDERS_DATA_DIR", 'orders/')
    full_path = os.path.join(orders_dir, basefilename)
    return open(full_path)
def get_learner_data_file(basefilename):
    """
    Open `basefilename` for reading from the learner data folder
    (LEARNER_DATA_DIR, default 'Data/') and return the open file object;
    the caller closes it.
    """
    learner_dir = os.environ.get("LEARNER_DATA_DIR", 'Data/')
    full_path = os.path.join(learner_dir, basefilename)
    return open(full_path, 'r')
def get_robot_world_file(basefilename):
    """
    Open `basefilename` from the robot worlds folder (ROBOT_WORLDS_DIR,
    default 'testworlds/') and return the open file object; the caller
    closes it.
    """
    worlds_dir = os.environ.get("ROBOT_WORLDS_DIR", 'testworlds/')
    full_path = os.path.join(worlds_dir, basefilename)
    return open(full_path)
def test_code():
    """
    Smoke test: download price and volume for two tickers over a short
    look-back window ending today.
    """
    tickers = ['GOOG', 'AMZN']
    # look-back window expressed in years
    lookback_years = 0.08
    end_dt = dt.datetime.today()
    start_dt = end_dt - dt.timedelta(days = 365 * lookback_years + 1)
    # If either end of the window falls on a non-trading day, you might get
    # warnings saying "No data found for this date range, symbol may be
    # delisted". This is normal behavior.
    window = pd.date_range(start_dt, end_dt)
    prices, volume = get_price_volume(tickers, window, addSPY=False)
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_code()