python金融数据分析和可视化--04利用Baostock获取股票数据

04利用Baostock获取股票数据

1. Baostock平台介绍

Baostock是一个免费、开源的证券数据平台(无需注册)。

提供大量准确、完整的证券历史行情数据、上市公司财务数据等。
通过python API获取证券数据信息,满足量化交易投资者、数量金融爱好者、计量经济从业者数据需求。
返回的数据格式:pandas DataFrame类型,以便于用pandas/NumPy/Matplotlib进行数据分析和可视化。
支持语言:BaoStock目前只支持Python 3.5及以上版本(暂不支持Python 2.x)。

2. 下载安装

使用国内源安装:

pip install baostock -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn

3. 获取股票数据

import baostock as bs
import pandas as pd

#### Log in ####
lg = bs.login()
# Show what the login call returned.
print('login respond error_code:'+lg.error_code)
print('login respond  error_msg:'+lg.error_msg)

try:
    #### Fetch historical K-line data for Shanghai/Shenzhen A shares ####
    # For the full list of fields see the "historical quote field parameters"
    # chapter of the Baostock docs; minute-line fields differ from daily-line
    # fields, and minute lines do not cover indices.
    # Minute-line fields: date,time,code,open,high,low,close,volume,amount,adjustflag
    # Weekly/monthly fields: date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg
    rs = bs.query_history_k_data_plus("sh.600000",
        "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST",
        start_date='2017-07-01', end_date='2017-12-31',
        frequency="d", adjustflag="3")
    print('query_history_k_data_plus respond error_code:'+rs.error_code)
    print('query_history_k_data_plus respond  error_msg:'+rs.error_msg)

    #### Collect the result set ####
    data_list = []
    # `and` short-circuits, so rs.next() is not called once the query has
    # reported an error (the former `&` always evaluated both sides).
    while rs.error_code == '0' and rs.next():
        # Fetch one record and append it to the accumulated rows.
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)

    #### Write the result set to a CSV file ####
    result.to_csv("D:\\history_A_stock_k_data.csv", index=False)
    print(result)
finally:
    #### Log out, even when the query or the file write failed ####
    bs.logout()

4. 获取股票数据本地存储

import baostock as bs
import pandas as pd
import datetime
import sys


def get_stock_list(date=None, max_lookback_days=30):
    """
    Return the list of A-share codes for a given date.

    If ``date`` is None, the query is first made without a date; when that
    returns no rows, earlier calendar days are tried one by one until a day
    with data is found.
    If ``date`` is given and the query returns rows, the codes of that day
    are returned.
    If ``date`` is given but the query returns no rows, a hint is printed and
    the program exits via ``sys.exit(0)``.

    The Baostock session opened here is always closed again, including when
    the function exits early or raises.

    :param date: day to query (the Baostock docs describe a 'YYYY-MM-DD'
        string), or None
    :param max_lookback_days: how many calendar days to walk back when
        ``date`` is None and the first query is empty
    :return: list of code strings
    :raises RuntimeError: if no day with data is found within
        ``max_lookback_days`` days
    """

    bs.login()
    try:
        stock_df = bs.query_all_stock(date).get_data()
        print(stock_df)

        # An empty result is treated as "not a trading day / no data yet".
        if 0 == len(stock_df):

            # An explicit date was requested: tell the user and stop.
            if date is not None:
                print('当前选择日期为非交易日或尚无交易数据,请设置date为历史某交易日日期')
                sys.exit(0)

            # No date requested: walk back day by day until data shows up.
            # The walk is bounded so that a persistent failure (e.g. no
            # connection) cannot loop forever.
            today = datetime.date.today()
            for delta in range(1, max_lookback_days + 1):
                # Pass the day as an ISO 'YYYY-MM-DD' string, not a date object.
                day = (today - datetime.timedelta(days=delta)).isoformat()
                stock_df = bs.query_all_stock(day).get_data()
                if 0 != len(stock_df):
                    break
            else:
                raise RuntimeError(
                    'no stock list found within the last {} days'.format(max_lookback_days))
    finally:
        bs.logout()

    # Keep only codes that sort (as strings) in ['sh.600000', 'sz.399000');
    # per the original author's note, Shanghai and Shenzhen stock codes fall
    # in this range.
    stock_df = stock_df[(stock_df['code'] >= 'sh.600000') & (stock_df['code'] < 'sz.399000')]

    # Return the codes as a plain list.
    return stock_df['code'].tolist()


def download_stock(code, start, end, freq="d", adjust="2",
                   base_dir="I:\\baostock\\stock_datas\\stock_download"):
    """
    Download K-line data for one security and save it as a CSV file.

    The file is written below ``base_dir`` into the ``day``, ``week``,
    ``month`` or ``minute`` sub-directory depending on ``freq``; minute files
    are named ``<code>_<freq>.csv``, all others ``<code>.csv``.

    :param code: security code passed to ``bs.query_history_k_data_plus``
    :param start: start date, passed through as ``start_date``
    :param end: end date, passed through as ``end_date``
    :param freq: "d", "w", "m" or one of "5", "15", "30", "60"
    :param adjust: value passed through as ``adjustflag``
    :param base_dir: root directory of the output tree (Windows-style path,
        without trailing separator)
    :return: None. On an unsupported ``freq`` or a failed query a message is
        printed and nothing is written.
    """
    minute_freqs = ("5", "15", "30", "60")
    subdirs = {"d": "day", "w": "week", "m": "month"}

    # Validate freq up front so that no remote query is made for a bad value.
    if freq == "d":
        fields = "date,open,high,low,close,preclose,volume,amount,turn,pctChg"
    elif freq in subdirs:
        fields = "date,open,high,low,close,volume,amount,turn,pctChg"
    elif freq in minute_freqs:
        # No blank inside the list: field names are sent comma-separated.
        fields = "time,code,open,high,low,close,volume,amount"
    else:
        print("freq 错误")
        return

    rs = bs.query_history_k_data_plus(
        code,
        fields,
        start_date=start, end_date=end,
        frequency=freq, adjustflag=adjust)
    if rs.error_code != '0':
        # Report the failure instead of silently writing an empty file.
        print('query_history_k_data_plus respond error_code:' + rs.error_code)
        print('query_history_k_data_plus respond  error_msg:' + rs.error_msg)
        return

    # Collect the result set row by row.
    data_list = []
    while rs.error_code == '0' and rs.next():
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)

    # Write the result set to a CSV file.
    if freq in minute_freqs:
        # Drop the last three characters of each time string before parsing
        # (presumably a millisecond suffix — confirm against the API docs).
        result["time"] = [t[:-3] for t in result["time"]]
        result["time"] = pd.to_datetime(result["time"])
        result = result.loc[:, ['time', 'open', 'high', 'low', 'close', 'volume', 'amount']]
        result.rename(columns={'time': 'datetime'}, inplace=True)
        result.set_index("datetime", drop=True, inplace=True)
        result.to_csv(base_dir + "\\minute\\" + code + "_" + freq + ".csv")
    else:
        result.set_index("date", drop=True, inplace=True)
        result.to_csv(base_dir + "\\" + subdirs[freq] + "\\" + code + ".csv")


if __name__ == "__main__":
    # Download every listed code at each frequency, from 1990 up to today.
    start = "1990-01-01"
    # Use the same 'YYYY-MM-DD' layout as `start`.
    end = datetime.datetime.today().strftime("%Y-%m-%d")
    list_data = get_stock_list()

    bs.login()
    try:
        # Daily/weekly/monthly lines first, then the 30- and 60-minute lines.
        for freq in ["d", "w", "m", "30", "60"]:
            for code in list_data:
                download_stock(code=code, start=start, end=end, freq=freq)
                # Progress output: the code that was just processed.
                print(code)
    finally:
        # Close the session even if a download raised.
        bs.logout()

5. 将股票数据存到MySQL数据库中

"""
date:20230318
将CSV文件写入到MySQL中
"""
import baostock as bs
import pandas as pd
from sqlalchemy import create_engine


def connect_db(db):
    """
    Build a SQLAlchemy engine for the given MySQL database.

    :param db: database name, inserted into the connection URL
    :return: the object returned by ``create_engine``
    """
    # NOTE(review): user name and password are embedded in the URL below;
    # consider loading them from configuration instead.
    url_template = 'mysql+pymysql://hao:671010@localhost:3306/{}?charset=utf8'
    return create_engine(url_template.format(db))


def get_all_codes():
    """
    Query the full security list and save it to ``all_stock.csv``.

    Logs in to Baostock, calls ``bs.query_all_stock(day=None)``, writes the
    rows to ``I:\\baostock\\stock_datas\\day\\all_stock.csv`` (without the
    index), prints the resulting DataFrame and logs out again. The logout
    also runs when the query or the file write raises.

    :return: None
    """
    # Log in.
    bs.login()
    try:
        # Fetch the security list.
        # NOTE(review): no date is given, so the result may be empty on days
        # without data — confirm whether a fallback date is needed.
        rs = bs.query_all_stock(day=None)

        # Collect the result set; `and` stops calling rs.next() after an error.
        data_list = []
        while rs.error_code == '0' and rs.next():
            # Fetch one record and append it to the accumulated rows.
            data_list.append(rs.get_row_data())
        result = pd.DataFrame(data_list, columns=rs.fields)

        # Write the result set to a CSV file.
        result.to_csv("I:\\baostock\\stock_datas\\day\\all_stock.csv", index=False)
        print(result)
    finally:
        # Log out.
        bs.logout()
def create_stock(bscode, db, code, date):
    """
    Load one local CSV file into a MySQL table.

    For ``date`` in 'day'/'week'/'month' the file
    ``I:\\baostock\\stock_datas\\<date>\\<bscode>.csv`` is read; for ``date``
    in '5'/'15'/'30'/'60' the file
    ``I:\\baostock\\stock_datas\\minute\\<bscode>_<date>.csv`` is read. The
    data is written to the table ``bs_<code>_<date>``, replacing it if it
    already exists.

    :param bscode: file name stem of the CSV to read
    :param db: database name handed to ``connect_db``
    :param code: middle part of the target table name
    :param date: 'day', 'week', 'month', '5', '15', '30' or '60'
    :return: None. For any other ``date`` a message is printed and nothing is
        read or written.
    """
    if date in ['day', 'week', 'month']:
        csv_path = 'I:\\baostock\\stock_datas\\' + date + '\\{}.csv'.format(bscode)
    elif date in ['5', '15', '30', '60']:
        csv_path = 'I:\\baostock\\stock_datas\\minute\\' + bscode + '_{}.csv'.format(date)
    else:
        # Stop here: there is no DataFrame to upload for an unknown period.
        print("date错误")
        return

    # Read the local CSV file.
    df = pd.read_csv(csv_path)

    engine = connect_db(db)
    try:
        # Original author's note: the table name must be all lowercase,
        # otherwise an error is raised.
        df.to_sql(name='bs_' + code + '_{}'.format(date), con=engine, index=False, if_exists='replace')
    finally:
        # A new engine is created per call, so release its connections here.
        engine.dispose()


def read_csv(code, data_dir='I:\\baostock\\stock_datas\\day\\'):
    """
    Read ``<data_dir><code>.csv`` and keep only rows in the stock code range.

    :param code: file name stem of the CSV to read
    :param data_dir: directory prefix, including the trailing path separator
    :return: DataFrame with the rows whose 'code' value sorts (as a string)
        in ['sh.600000', 'sz.399000')
    """
    # Every backslash in the default path is escaped explicitly; the previous
    # literal contained the invalid escape sequence '\s'.
    df0 = pd.read_csv(data_dir + '{}.csv'.format(code))
    # Per the original author's note, Shanghai and Shenzhen stock codes lie
    # between sh.600000 and sz.399000.
    df = df0[(df0['code'] >= 'sh.600000') & (df0['code'] < 'sz.399000')]
    return df


if __name__=="__main__":
    # Refresh the security list, load it into MySQL, then load every
    # per-security CSV file for each period.
    get_all_codes()
    stockDB = 'baostock_db'
    stockList = 'all_stock'
    create_stock(bscode=stockList, db=stockDB, code=stockList, date="day")
    df1 = read_csv(stockList)
    li1 = list(df1['code'])
    for date in ["day", "week", "month", "5", "15", "30", "60"]:
        for bs_code in li1:
            # Drop the exchange prefix ('sh.', 'sz.', ...) by splitting on the
            # first dot; lstrip() strips a character set, not a prefix.
            codeStock = bs_code.split('.', 1)[-1]
            try:
                create_stock(bscode=bs_code, db=stockDB, code=codeStock, date=date)
            except FileNotFoundError:
                # A CSV that was never downloaded must not abort the run.
                print('skip {} ({}): CSV file not found'.format(bs_code, date))
                continue
            print(codeStock)