"""
|
|
Use Yahoo Finance data
|
|
"""
|
|
|
|
import warnings
|
|
|
|
# Suppress FutureWarnings
|
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
|
|
|
import datetime as dt
|
|
import os
|
|
import pandas as pd
|
|
import numpy as np
|
|
# import yfinance as yf
|
|
import yahoo_fin.stock_info as si
|
|
import requests
|
|
from lxml import html
|
|
from io import StringIO
|
|
from time import sleep
|
|
|
|
WEBSITE = 'https://www.isolo.org/dokuwiki/knowledge_base/investing/watchlist'
|
|
BATCHSIZE = 20
|
|
TIMEGAP = 0.2
|
|
|
|
def fill_missing_data(df):
    """Fill NaN gaps in *df* and return the result as a new frame.

    Gaps are forward-filled first (carry the last known value ahead),
    then back-filled so any leading NaNs take the first known value.
    """
    return df.ffill().bfill()
|
|
|
|
def symbol_to_path(symbol, base_dir=None):
    """Build the path of the CSV file holding data for *symbol*.

    When *base_dir* is omitted, the MARKET_DATA_DIR environment variable
    is consulted, defaulting to '../data/'.
    """
    if base_dir is None:
        base_dir = os.environ.get("MARKET_DATA_DIR", '../data/')
    filename = "{}.csv".format(str(symbol))
    return os.path.join(base_dir, filename)
|
|
|
|
# def get_data_full(symbols, start_date, addSPY=True, colname = 'Adj Close'):
|
|
# """
|
|
# Read stock data (adjusted close) for given symbols from Yahoo Finance
|
|
# from start_date to the latest date available (usually the current date).
|
|
# """
|
|
# if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
|
|
# symbols = ['SPY'] + symbols
|
|
|
|
# df = yf.download(symbols, start = start_date)[colname]
|
|
# if len(symbols) == 1:
|
|
# df.name = symbols[0]
|
|
# df = df.to_frame()
|
|
# return df
|
|
|
|
# def get_data_full(symbols, start_date, addSPY=True, colname = 'Adj Close'):
|
|
"""
|
|
Read stock data (adjusted close) for given symbols from CSV files
|
|
from start_date to the latest date available in the CSV files.
|
|
"""
|
|
# df_temp = pd.read_csv(symbol_to_path('SPY'), index_col='Date',
|
|
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
|
|
# df_temp = df_temp.rename(columns={colname: 'SPY'})
|
|
# end_date = df_temp.index.values[-1]
|
|
# dates = pd.date_range(start_date, end_date)
|
|
# df = pd.DataFrame(index=dates)
|
|
# df = df.join(df_temp)
|
|
# df = df.dropna()
|
|
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
|
|
# # symbols = ['SPY'] + symbols
|
|
# for symbol in symbols:
|
|
# df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date',
|
|
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
|
|
# df_temp = df_temp.rename(columns={colname: symbol})
|
|
# df = df.join(df_temp)
|
|
# # if symbol == 'SPY': # drop dates SPY did not trade
|
|
# # df = df.dropna(subset=["SPY"])
|
|
# if not addSPY:
|
|
# df = df[symbols]
|
|
# return df
|
|
|
|
def get_data_range(df, dates):
    """Return the rows of *df* whose index falls in *dates*.

    Implemented as an inner join against an empty frame indexed by
    *dates*, so dates absent from *df* are silently dropped.
    """
    window = pd.DataFrame(index=dates)
    return window.join(df, how='inner')
|
|
|
|
def yf_download(symbols, start, end):
    """Download adjusted close and volume for each ticker in *symbols*.

    Data is fetched with yahoo_fin from *start* through the latest date
    available; *end* is accepted for interface compatibility but is not
    forwarded to yahoo_fin (TODO confirm whether truncation at *end* is
    intended).

    Returns a dataframe whose columns are a two-level MultiIndex:
    level 0 in {"Adj Close", "Volume"}, level 1 the ticker symbol —
    the layout the rest of this module indexes with (df['Adj Close'],
    ('Adj Close', 'SPY')).
    """
    # Map yahoo_fin's lower-case field names onto the labels used
    # module-wide so downstream lookups like df['Adj Close'] succeed.
    field_map = {"adjclose": "Adj Close", "volume": "Volume"}
    df = pd.DataFrame(columns=pd.MultiIndex(levels=[["Adj Close", "Volume"], []],
                                            codes=[[], []], names=[None, None]))
    for sym in symbols:
        tmp = si.get_data(sym, start_date=start)[["adjclose", "volume"]]
        # Bug fix: tag columns with the current symbol. The original used
        # symbols[0] here, mislabeling every ticker after the first.
        tmp.columns = pd.MultiIndex.from_tuples(
            [(field_map[c], sym) for c in tmp.columns], names=[None, None])
        df = df.join(tmp, how='outer')

    return df
|
|
|
|
# def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'):
|
|
# """
|
|
# Read stock data (adjusted close) for given symbols from Yahoo Finance
|
|
# """
|
|
# org_sym = symbols
|
|
# sd = dates[0]
|
|
# ed = dates[-1]
|
|
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
|
|
# if 'SPY' not in symbols:
|
|
# symbols = ['SPY'] + symbols
|
|
# df = yf.download(symbols, start=sd, end = ed)[colname]
|
|
# if len(symbols) == 1:
|
|
# df.name = symbols[0]
|
|
# df = df.to_frame()
|
|
|
|
# df = df.dropna(subset=['SPY'])
|
|
# df = fill_missing_data(df)
|
|
|
|
# if addSPY==False:
|
|
# # df = df.drop(columns=['SPY'])
|
|
# df = df[org_sym]
|
|
|
|
# return df
|
|
|
|
def yf_batch_download(symbols, start, end, batch_size, time_gap):
    """
    Download in small batches to avoid connection closure by host.

    Parameters
    ----------
    symbols : list
        stock symbols.
    start : datetime
        start date.
    end : datetime
        stop date.
    batch_size : integer
        batch size.
    time_gap : float
        pause between batches, in seconds or fraction of seconds.

    Returns
    -------
    df : dataframe
        stock price volume information.
    """
    # Bug fix: the old floor-division batching folded the remainder into
    # the final request, so up to 2*batch_size - 1 symbols could be sent
    # at once (defeating the purpose of batching). Step slicing yields
    # uniform batches of at most batch_size symbols each.
    n = len(symbols)
    df = pd.DataFrame()
    for i in range(0, n, batch_size):
        if i:                     # pause only between batches, as before
            sleep(time_gap)
        tmp = yf_download(symbols[i:i + batch_size], start, end)
        df = pd.concat([df, tmp], axis=1)

    return df
|
|
|
|
def get_price_volume(symbols, dates, addSPY=False):
    """
    Read stock data (adjusted close and volume) for given symbols from local
    file unless data is not in local. It only gets date from Yahoo Finance
    when necessary to increase speed and reduce internet data.

    It will refresh local data if the symbols are on the _refresh.csv. This
    is necessary when stock splits, spins off or something else happens.

    NOTE(review): the local-cache behavior described above is currently
    disabled (see the commented-out DATAFILE/REFRESH logic below); as
    written, every call downloads everything from Yahoo Finance.
    """
    # DATAFILE = "_stkdata.pickle"
    # REFRESH = "_refresh.csv"
    org_sym = symbols      # remember caller's list before SPY is prepended
    sd = dates[0]          # requested start date
    ed = dates[-1]         # requested end date
    # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
    # SPY is always fetched: it is used below to drop non-trading days.
    if 'SPY' not in symbols:
        symbols = ['SPY'] + symbols

    df = yf_batch_download(symbols, start=sd, end=ed, \
                           batch_size=BATCHSIZE, time_gap=TIMEGAP)
    if len(symbols) == 1:
        # Single-symbol downloads come back with flat columns; rebuild the
        # (field, symbol) MultiIndex the rest of this function expects.
        tuples = list(zip(df.columns.values.tolist(), \
                          [symbols[0]]*len(df.columns.values)))
        df.columns = pd.MultiIndex.from_tuples(tuples, names=[None, None])

    # if not os.path.exists(DATAFILE):
    #     df = yf_batch_download(symbols, start=sd, end=ed, \
    #                            batch_size=BATCHSIZE, time_gap=TIMEGAP)
    #     if len(symbols) == 1:
    #         tuples = list(zip(df.columns.values.tolist(), \
    #                           [symbols[0]]*len(df.columns.values)))
    #         df.columns = pd.MultiIndex.from_tuples(tuples, names=[None, None])
    # else:
    #     df = pd.read_pickle(DATAFILE)
    #     exist_syms = df["Adj Close"].columns.values.tolist()
    #     if os.path.exists(REFRESH):
    #         try:
    #             refresh_df = pd.read_csv(REFRESH, header=None)
    #             refresh_syms = refresh_df.values.tolist()
    #             refresh_syms = [x for sublist in refresh_syms for x in sublist]
    #             remove_syms = [x for x in exist_syms if x in refresh_syms]
    #             if remove_syms:
    #                 df.drop(columns=remove_syms, axis=1, level=1, inplace=True)
    #             exist_syms = [x for x in exist_syms if x not in refresh_syms]
    #         except:
    #             pass

    # Cache disabled: treat no symbols as locally present.
    exist_syms = []

    last_day = pd.to_datetime(df.index.values[-1])
    first_day = pd.to_datetime(df.index.values[0])
    intersect_syms = list(set(org_sym) & set(exist_syms))
    # reduce df to only contain intersect_syms
    # NOTE(review): with the cache disabled, intersect_syms is always empty,
    # so this empties df; the "new_stks" download below re-fetches all the
    # symbols, making the initial download above redundant — TODO confirm.
    df = df.loc[:, (slice(None), intersect_syms)]

    if sd < first_day:
        # fill gap from online
        tmp_df = yf_batch_download(intersect_syms, start=sd, end=first_day, \
                                   batch_size=BATCHSIZE, time_gap=TIMEGAP)
        df = pd.concat([tmp_df, df])

    if ed >= last_day:
        # fill gap from online incl last two days to get mkt close data
        if ed.date() == last_day.date():
            tmp_df = yf_batch_download(intersect_syms, start=ed, end=ed, \
                                       batch_size=BATCHSIZE, time_gap=TIMEGAP)
        else:
            tmp_df = yf_batch_download(intersect_syms, start=last_day, end=ed, \
                                       batch_size=BATCHSIZE, time_gap=TIMEGAP)
        df = pd.concat([df[:-1], tmp_df])

    # get data online when new stks were added
    new_stks = np.setdiff1d(symbols, exist_syms).tolist()
    if not new_stks == []:
        tmp_df = yf_batch_download(new_stks, start=sd, end=ed, \
                                   batch_size=BATCHSIZE, time_gap=TIMEGAP)
        if len(new_stks) == 1:
            # Same single-symbol MultiIndex fix-up as above.
            tuples = list(zip(tmp_df.columns.values.tolist(), \
                              [new_stks[0]]*len(tmp_df.columns.values)))
            tmp_df.columns = pd.MultiIndex.from_tuples(tuples, names=[None, None])
        df = df.join(tmp_df)

    # df.to_pickle(DATAFILE) # save to local, overwrite existing file
    # if os.path.exists(REFRESH):
    #     with open(REFRESH, 'w'):
    #         pass

    # Keep only days on which SPY traded, then split the two-level frame
    # into separate price and volume frames.
    df = df.dropna(subset=[('Adj Close', 'SPY')])
    price = df['Adj Close']
    price = fill_missing_data(price)
    volume = df['Volume']
    volume = volume.fillna(0)    # missing volume is treated as no trading

    # if len(symbols) == 1:
    #     price.name = symbols[0]
    #     volume.name = symbols[0]
    #     price = price.to_frame()
    #     volume = volume.to_frame()

    if addSPY==False:
        # Drop the SPY helper column(s) unless the caller asked for them.
        price = price[org_sym]
        volume = volume[org_sym]

    return price, volume
|
|
|
|
|
|
# def get_price_volume_online(symbols, dates, addSPY=False):
|
|
# """
|
|
# Read stock data (adjusted close and volume) for given symbols from Yahoo
|
|
# Finance
|
|
# """
|
|
# org_sym = symbols
|
|
# sd = dates[0]
|
|
# ed = dates[-1]
|
|
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
|
|
# if 'SPY' not in symbols:
|
|
# symbols = ['SPY'] + symbols
|
|
# df = yf.download(symbols, start=sd, end = ed)
|
|
# if len(symbols) == 1:
|
|
# df = df.dropna(subset = ['Adj Close'])
|
|
# else:
|
|
# df = df.dropna(subset=[('Adj Close', 'SPY')])
|
|
# price = df['Adj Close']
|
|
# price = fill_missing_data(price)
|
|
# volume = df['Volume']
|
|
# volume = volume.fillna(0)
|
|
|
|
# if len(symbols) == 1:
|
|
# price.name = symbols[0]
|
|
# volume.name = symbols[0]
|
|
# price = price.to_frame()
|
|
# volume = volume.to_frame()
|
|
|
|
# if addSPY==False:
|
|
# price = price[org_sym]
|
|
# volume = volume[org_sym]
|
|
|
|
# return price, volume
|
|
|
|
def get_watchlist(website: str = WEBSITE):
    """Fetch the stock watchlist published on the wiki page *website*.

    The page is expected to contain a CSV table inside a <pre> element;
    it is parsed with 'Symbol' as the index column, '#' as the comment
    character, and NA filtering disabled.

    Bug fix: the *website* parameter was previously ignored — the module
    constant WEBSITE was always fetched regardless of the argument.
    """
    page = requests.get(website)
    # page = requests.get(website, verify=False) # skip certificate check for https
    tree = html.fromstring(page.content)
    watchlist = tree.xpath('//*[@id="dokuwiki__content"]/div[1]/div/div[3]/div/pre/text()')[0]
    file_name = StringIO(watchlist)
    df = pd.read_csv(file_name, index_col = 'Symbol',
                     comment = '#', na_filter=False)
    return df
|
|
|
|
# def get_watchlist(file_name: str = 'watchlist.csv'):
|
|
# df = pd.read_csv(file_name, index_col = 'Symbol',
|
|
# comment = '#', na_filter=False)
|
|
# return df
|
|
|
|
# def get_data(symbols, dates, addSPY=True, colname = 'Adj Close'):
|
|
# """
|
|
# Read stock data (adjusted close) for given symbols from CSV files.
|
|
|
|
# (done) TODO: there are nan values in the data when addSPY=False is passed. The
|
|
# strategy should be using SPY to clean the data first including fill
|
|
# forward and fill backward, then to drop the SPY if addSPY=False
|
|
# """
|
|
# org_sym = symbols
|
|
# df = pd.DataFrame(index=dates)
|
|
# # if addSPY and 'SPY' not in symbols: # add SPY for reference, if absent
|
|
# # symbols = ['SPY'] + symbols
|
|
# if 'SPY' not in symbols:
|
|
# symbols = ['SPY'] + symbols
|
|
# for symbol in symbols:
|
|
# df_temp = pd.read_csv(symbol_to_path(symbol), index_col='Date',
|
|
# parse_dates=True, usecols=['Date', colname], na_values=['nan'])
|
|
# df_temp = df_temp.rename(columns={colname: symbol})
|
|
# df = df.join(df_temp)
|
|
# if symbol == 'SPY': # drop dates SPY did not trade
|
|
# df = df.dropna(subset=["SPY"])
|
|
# # fill missing data
|
|
# df = fill_missing_data(df)
|
|
# if addSPY == False: # drop SPY
|
|
# # df = df.drop(columns=['SPY'])
|
|
# df = df[org_sym]
|
|
|
|
# return df
|
|
|
|
|
|
def plot_data(df, axs=None, title=None, xlabel='', ylabel=''):
    """Plot stock prices with a custom title and meaningful axis labels.

    Parameters
    ----------
    df : dataframe
        data to plot (one line per column).
    axs : matplotlib axes, optional
        existing axes to draw on; a new one is created when omitted.
    title : plot title, optional
    xlabel, ylabel : str
        axis labels.

    Bug fix: the defaults were mutable lists (axs=[], title=[]), a shared
    mutable-default hazard. None sentinels are mapped back to [] so the
    observable behavior (including what is passed to df.plot) is unchanged.
    """
    if axs is None:
        axs = []
    if title is None:
        title = []
    if axs == []:
        ax = df.plot(title=title)
    else:
        ax = df.plot(ax=axs, title=title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid()
|
|
|
|
|
|
# def plot_data(df, title=[], xlabel='', ylabel=''):
|
|
# import matplotlib.pyplot as plt
|
|
# """Plot stock prices with a custom title and meaningful axis labels."""
|
|
# ax = df.plot(title=title, fontsize=12, figsize=(10, 7))
|
|
# ax.set_xlabel(xlabel)
|
|
# ax.set_ylabel(ylabel)
|
|
# plt.grid()
|
|
# plt.show()
|
|
|
|
def get_orders_data_file(basefilename):
    """Open (read mode) an orders file located under ORDERS_DATA_DIR.

    The directory defaults to 'orders/' when the environment variable is
    unset. The caller is responsible for closing the returned handle.
    """
    orders_dir = os.environ.get("ORDERS_DATA_DIR", 'orders/')
    return open(os.path.join(orders_dir, basefilename))
|
|
|
|
def get_learner_data_file(basefilename):
    """Open (text read mode) a learner data file under LEARNER_DATA_DIR.

    The directory defaults to 'Data/' when the environment variable is
    unset. The caller is responsible for closing the returned handle.
    """
    data_dir = os.environ.get("LEARNER_DATA_DIR", 'Data/')
    return open(os.path.join(data_dir, basefilename), 'r')
|
|
|
|
def get_robot_world_file(basefilename):
    """Open (read mode) a robot world file under ROBOT_WORLDS_DIR.

    The directory defaults to 'testworlds/' when the environment variable
    is unset. The caller is responsible for closing the returned handle.
    """
    worlds_dir = os.environ.get("ROBOT_WORLDS_DIR", 'testworlds/')
    return open(os.path.join(worlds_dir, basefilename))
|
|
|
|
|
|
def test_code():
    """Smoke test: download roughly a month of GOOG/AMZN price history."""
    tickers = ['GOOG', 'AMZN']
    lookback_years = 0.08          # lookback window in years (~29 days)
    end = dt.datetime.today()
    start = end - dt.timedelta(days = 365 * lookback_years + 1)
    # If end or start falls on a non-trading day, warnings saying
    # "No data found for this date range, symbol may be delisted" are
    # normal behavior.
    prices, volume = get_price_volume(tickers, pd.date_range(start, end), addSPY=False)
|
|
|
|
|
|
# Script entry point: run a quick end-to-end download smoke test.
if __name__ == '__main__':
    test_code()